From 3206af9db2bcd5e050b542df37bb4f091b24804c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 6 Aug 2021 16:20:12 +0800
Subject: [PATCH] perf(functional/matmul): reimplement matmul with subgraph

GitOrigin-RevId: 456b2a51d35852152c46d31548dac96f977d5b41
---
 .../python/megengine/functional/math.py | 318 +++++++++++++-----
 1 file changed, 243 insertions(+), 75 deletions(-)

diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index 0488e257d..eadc6ec85 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -8,17 +8,21 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 import math
+from functools import lru_cache
 from typing import Optional, Sequence, Tuple, Union
 
 from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
+from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
 from ..core.ops.special import Const
 from ..core.tensor import amp
-from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar
+from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
+from ..jit import exclude_from_trace
 from ..tensor import Tensor
 from .debug_param import get_execution_strategy
-from .elemwise import clip
+from .elemwise import clip, minimum
 from .tensor import broadcast_to, concat, expand_dims, squeeze
 
 __all__ = [
@@ -763,6 +767,216 @@ def matinv(inp: Tensor) -> Tensor:
     return result
 
 
+@lru_cache(maxsize=None)
+def _get_extentedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    def extentedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-1):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-1):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+
+        shape1 = f(GetVarShape(), inp1)
+        shape2 = f(GetVarShape(), inp2)
+        if _dim1 > 2:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
+                    build_shape_tail(shape1),
+                ),
+            )
+        if _dim2 > 2:
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
+                    build_shape_tail(shape2),
+                ),
+            )
+        op = builtin.MatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy,
+        )
+        result = f(op, inp1, inp2)
+        result_shape = f(GetVarShape(), result)
+        if _dim1 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape1),
+                    build_shape_tail(result_shape),
+                ),
+            )
+        if _dim2 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape2),
+                    build_shape_tail(result_shape),
+                ),
+            )
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedMatrixMulOp
+
+
+@lru_cache(maxsize=None)
+def _get_extentedBatchedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    def extentedBatchedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-2):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-2):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+        shape1 = f(GetVarShape(), inp1)
+        shape2 = f(GetVarShape(), inp2)
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if _dim1 > _dim2:
+            # broadcast
+            shape2 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
+                shape2,
+            )
+            inp2 = f(builtin.Broadcast(), inp2, shape2)
+            batch_shape = build_shape_head(shape1)
+        if _dim2 > _dim1:
+            # broadcast
+            shape1 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
+                shape1,
+            )
+            inp1 = f(builtin.Broadcast(), inp1, shape1)
+            batch_shape = build_shape_head(shape2)
+        if _dim1 == _dim2:
+            batch_shape = build_shape_head(shape1)
+
+        # compress inputs to 3d
+        if maxdim > 3:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape1),
+                ),
+            )
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape2),
+                ),
+            )
+        op = builtin.BatchedMatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy,
+        )
+        result = f(op, inp1, inp2)
+
+        if maxdim > 3:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    batch_shape,
+                    build_shape_tail(f(GetVarShape(), result)),
+                ),
+            )
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedBatchedMatrixMulOp
+
+
 def matmul(
     inp1: Tensor,
     inp2: Tensor,
@@ -822,85 +1036,39 @@ def matmul(
         if inp2.dtype != dtype:
             inp2 = inp2.astype(dtype)
 
-    remove_row, remove_col = False, False
     dim1, dim2 = inp1.ndim, inp2.ndim
-    # handle dim=1 cases, dot and matrix-vector multiplication
-    if dim1 == 1 and dim2 == 1:
-        return dot(inp1, inp2)
-    # the underlying matmul op requires input dims to be at least 2
-    if dim1 == 1:
-        inp1 = expand_dims(inp1, 0)
-        dim1 = 2
-        remove_row = True
-    if dim2 == 1:
-        inp2 = expand_dims(inp2, 1)
-        dim2 = 2
-        remove_col = True
-
-    batch_shape = None
-    shape1 = inp1.shape
-    shape2 = inp2.shape
-
+    assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
-    if dim1 >= 3 or dim2 >= 3:
-        if use_symbolic_shape():
-            if dim1 > dim2:
-                shape2 = concat([shape1[:-2], shape2[-2:]])
-                inp2 = broadcast_to(inp2, shape2)
-            if dim1 < dim2:
-                shape1 = concat([shape2[:-2], shape1[-2:]])
-                inp1 = broadcast_to(inp1, shape1)
-            if maxdim > 3:
-                batch_shape = shape1[:-2]
-                # compress inputs to 3d
-                (inp1,) = apply(
-                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
-                )
-                (inp2,) = apply(
-                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
-                )
-        else:
-            if dim1 > dim2:
-                shape2 = shape1[:-2] + shape2[-2:]
-                inp2 = broadcast_to(inp2, shape2)
-            if dim1 < dim2:
-                shape1 = shape2[:-2] + shape1[-2:]
-                inp1 = broadcast_to(inp1, shape1)
-            if maxdim > 3:
-                batch_shape = shape1[:-2]
-                # compress inputs to 3d
-                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
-                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
-
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
+    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
+        return dot(inp1, inp2)
+    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
+        extentedMatrixMulOp = _get_extentedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
             strategy=get_execution_strategy(),
         )
-    else:
-        op = builtin.MatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
+        (result,) = apply(extentedMatrixMulOp, inp1, inp2)
+        return result
+    else:  # dispatch to BatchedMatrixMul
+        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
             strategy=get_execution_strategy(),
         )
-
-    (result,) = apply(op, inp1, inp2)
-    if maxdim > 3:
-        if use_symbolic_shape():
-            (result,) = apply(
-                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
-            )
-        else:
-            result = result.reshape(batch_shape + result.shape[-2:])
-    if remove_row:
-        result = squeeze(result, axis=-2)
-    if remove_col:
-        result = squeeze(result, axis=-1)
-    return result
+        (result,) = apply(extentedBatchedMatrixMulOp, inp1, inp2)
+        return result
 
 
 def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
-- 
GitLab
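Reviewer note, not part of the patch: the reworked `matmul` now chooses one of three paths purely from the input ranks -- `dim1 == dim2 == 1` goes to `dot`, `maxdim <= 2 or dim2 <= 2` goes to the cached MatrixMul subgraph, and everything else goes to the cached BatchedMatrixMul subgraph. Below is a minimal sketch of how the public API exercises each path; output shapes follow the numpy-style broadcasting semantics the previous implementation already provided, and the comments only restate the dispatch conditions above.

    # Sketch: shape behaviour across the three dispatch paths of the reworked matmul.
    import numpy as np
    import megengine.functional as F
    from megengine import Tensor

    a = Tensor(np.random.rand(3).astype("float32"))
    b = Tensor(np.random.rand(3).astype("float32"))
    print(F.matmul(a, b).shape)   # () -> 1-D x 1-D dispatches to dot

    m = Tensor(np.random.rand(4, 3).astype("float32"))
    v = Tensor(np.random.rand(3).astype("float32"))
    print(F.matmul(m, v).shape)   # (4,) -> matrix-vector, MatrixMul subgraph path

    x = Tensor(np.random.rand(2, 5, 4, 3).astype("float32"))
    w = Tensor(np.random.rand(3, 6).astype("float32"))
    print(F.matmul(x, w).shape)   # (2, 5, 4, 6) -> dim2 <= 2, MatrixMul subgraph path

    p = Tensor(np.random.rand(8, 4, 3).astype("float32"))
    q = Tensor(np.random.rand(8, 3, 6).astype("float32"))
    print(F.matmul(p, q).shape)   # (8, 4, 6) -> both inputs > 2-D, BatchedMatrixMul subgraph path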
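A related design note, also not part of the patch: both `_get_extentedMatrixMulOp` and `_get_extentedBatchedMatrixMulOp` are wrapped in `functools.lru_cache`, so the subgraph for a given combination of device, dtype, ranks, transpose flags, compute mode, format and strategy is built once and reused on later calls. A rough, self-contained illustration of that memoised-builder pattern follows; `build_matmul_plan` is a made-up stand-in for the real subgraph construction, not a MegEngine API.

    # The cache key is the whole argument tuple, so every argument must be hashable
    # (ranks and flags rather than tensors); construction runs only once per key.
    from functools import lru_cache

    @lru_cache(maxsize=None)
    def build_matmul_plan(device, dtype, dim1, dim2, transpose_a, transpose_b):
        print("building plan for", (device, dtype, dim1, dim2, transpose_a, transpose_b))
        return ("plan", device, dtype, dim1, dim2, transpose_a, transpose_b)

    p1 = build_matmul_plan("xpu0", "float32", 4, 2, False, False)  # builds and caches
    p2 = build_matmul_plan("xpu0", "float32", 4, 2, False, False)  # cache hit, no rebuild
    assert p1 is p2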