提交 1e0fb127 编写于 作者: M Megvii Engine Team

perf(mge/functional): reduce matmul python overhead

GitOrigin-RevId: 738d0da10ee7efdf3bd3e3030f10ac9c8aa16a5c
上级 72619bb4
......@@ -345,9 +345,10 @@ class ArrayMethodMixin(abc.ABC):
@property
def ndim(self):
shape = self.shape
# XXX: assume ndim is not changed during trace
if isinstance(shape, self.__class__):
shape = shape.numpy()
# XXX: assume ndim is not changed during trace
ndim = shape.__wrapped__.shape[0]
return ndim
return len(shape)
@property
......
......@@ -10,6 +10,7 @@
from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt import CompNode
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.builtin import BatchNorm
......@@ -1015,23 +1016,39 @@ def matmul(
remove_col = True
batch_shape = None
shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device)
shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device)
shape1 = inp1.shape
shape2 = inp2.shape
maxdim = dim1 if dim1 > dim2 else dim2
if dim1 >= 3 or dim2 >= 3:
if dim1 == dim2:
assert (
shape1[:-2] == shape2[:-2]
).min(), "operands could not be broadcasted together."
if dim1 > dim2:
shape2 = concat([shape1[:-2], shape2[-2:]])
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = concat([shape2[:-2], shape1[-2:]])
inp1 = broadcast_to(inp1, shape1)
batch_shape = shape1[:-2]
# compress inputs to 3d
inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]]))
inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]]))
if use_symbolic_shape():
if dim1 > dim2:
shape2 = concat([shape1[:-2], shape2[-2:]])
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = concat([shape2[:-2], shape1[-2:]])
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
(inp1,) = apply(
builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
)
(inp2,) = apply(
builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
)
else:
if dim1 > dim2:
shape2 = shape1[:-2] + shape2[-2:]
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = shape2[:-2] + shape1[-2:]
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
......@@ -1048,8 +1065,13 @@ def matmul(
)
(result,) = apply(op, inp1, inp2)
if batch_shape is not None:
result = result.reshape(concat([batch_shape, result.shape[-2:]]))
if maxdim > 3:
if use_symbolic_shape():
(result,) = apply(
builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
)
else:
result = result.reshape(batch_shape + result.shape[-2:])
if remove_row:
result = squeeze(result, axis=-2)
if remove_col:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册