Commit 1fa143ce authored by Megvii Engine Team

refactor(mge/functional): matmul supports symbolic shape, batched mv multiply

GitOrigin-RevId: c4d8cf3306cd833828eca0fc7372397cbf2cc36f
Parent d47cf332
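In effect, `F.matmul` now follows NumPy's `np.matmul` promotion rules: a 1-d operand is temporarily expanded to 2-d and the extra axis is squeezed out of the result, and batch dimensions are tracked through symbolic shape tensors rather than static Python shapes. Below is a minimal sketch of the newly supported batched matrix-vector case (it assumes only the public `megengine.functional.matmul` API and is not part of the commit):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    m = tensor(np.random.random((10, 3, 4)).astype("float32"))  # batch of matrices
    v = tensor(np.random.random((4,)).astype("float32"))        # a single vector

    # v is promoted to shape (4, 1), multiplied batch-wise, and the trailing
    # axis is squeezed away, so the result has shape (10, 3), as in NumPy
    out = F.matmul(m, v)
    print(out.shape)  # (10, 3)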
@@ -23,7 +23,7 @@ from ..tensor import Tensor
 from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
-from .math import argsort, max, sum
+from .math import argsort, max, prod, sum
 from .tensor import (
     broadcast_to,
     concat,
@@ -972,38 +972,42 @@ def matmul(
     [28. 40.]]

     """
+    remove_row, remove_col = False, False
     inp1, inp2 = utils.convert_inputs(inp1, inp2)
     dim1, dim2 = inp1.ndim, inp2.ndim
+    # handle dim=1 cases, dot and matrix-vector multiplication
     if dim1 == 1 and dim2 == 1:
         return dot(inp1, inp2)
-    shp = None
-    if dim1 > 3 or dim2 > 3:
-        shape1, shape2 = list(inp1.shape), list(inp2.shape)
-        if dim1 != dim2:
-            if dim1 < dim2:
-                shape1 = shape2[: dim2 - dim1] + shape1
-                inp1 = broadcast_to(inp1, shape1)
-            else:
-                shape2 = shape1[: dim1 - dim2] + shape2
-                inp2 = broadcast_to(inp2, shape2)
-        reshaped_batch_size = 1
-        for i in shape1[:-2]:
-            reshaped_batch_size *= i
-        inp1 = inp1.reshape(*([reshaped_batch_size] + shape1[-2:]))
-        inp2 = inp2.reshape(*([reshaped_batch_size] + shape2[-2:]))
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-        )
-        shp = shape1[:-1] + shape2[-1:]
-    elif dim1 == 3 or dim2 == 3:
-        if dim2 < 3:
-            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
-        elif dim1 < 3:
-            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
+    # the underlying matmul op requires input dims to be at least 2
+    if dim1 == 1:
+        inp1 = expand_dims(inp1, 0)
+        dim1 = 2
+        remove_row = True
+    if dim2 == 1:
+        inp2 = expand_dims(inp2, 1)
+        dim2 = 2
+        remove_col = True
+
+    batch_shape = None
+    shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device)
+    shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device)
+    if dim1 >= 3 or dim2 >= 3:
+        if dim1 == dim2:
+            assert (
+                shape1[:-2] == shape2[:-2]
+            ).min(), "operands could not be broadcasted together."
+        if dim1 > dim2:
+            shape2 = concat([shape1[:-2], shape2[-2:]])
+            inp2 = broadcast_to(inp2, shape2)
+        if dim1 < dim2:
+            shape1 = concat([shape2[:-2], shape1[-2:]])
+            inp1 = broadcast_to(inp1, shape1)
+        batch_shape = shape1[:-2]
+        # compress inputs to 3d
+        inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]]))
+        inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]]))
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1011,12 +1015,6 @@ def matmul(
             format=format,
         )
     else:
-        if dim1 == 1:
-            shp = (inp2.shape[1],)
-            inp1 = expand_dims(inp1, 0)
-        if dim2 == 1:
-            shp = (inp1.shape[0],)
-            inp2 = expand_dims(inp2, 1)
         op = builtin.MatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1025,8 +1023,12 @@ def matmul(
             format=format,
         )
     (result,) = apply(op, inp1, inp2)
-    if shp is not None:
-        result = result.reshape(shp)
+    if batch_shape is not None:
+        result = result.reshape(concat([batch_shape, result.shape[-2:]]))
+    if remove_row:
+        result = squeeze(result, axis=-2)
+    if remove_col:
+        result = squeeze(result, axis=-1)
     return result
...
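In plain NumPy terms, the reworked shape handling above amounts to the following sketch (`matmul_sketch` is a hypothetical name for illustration; the committed code operates on symbolic shape tensors via `astensor1d`/`concat`/`prod` instead of Python tuples):

    import numpy as np

    def matmul_sketch(a, b):
        # 1-d operands are promoted to 2-d, and the promoted axis is
        # squeezed back out of the result afterwards
        remove_row, remove_col = a.ndim == 1, b.ndim == 1
        if remove_row:
            a = a[None, :]  # (k,) -> (1, k)
        if remove_col:
            b = b[:, None]  # (k,) -> (k, 1)
        # batch handling mirrors the diff: equal rank requires equal batch
        # shapes; otherwise the lower-rank operand is broadcast up
        if a.ndim > b.ndim:
            b = np.broadcast_to(b, a.shape[:-2] + b.shape[-2:])
        elif a.ndim < b.ndim:
            a = np.broadcast_to(a, b.shape[:-2] + a.shape[-2:])
        else:
            assert a.shape[:-2] == b.shape[:-2], "operands could not be broadcasted together."
        batch = a.shape[:-2]
        # "compress inputs to 3d": fold all batch dims into a single one
        out = a.reshape((-1,) + a.shape[-2:]) @ b.reshape((-1,) + b.shape[-2:])
        out = out.reshape(batch + out.shape[-2:])
        if remove_row:
            out = out.squeeze(-2)
        if remove_col:
            out = out.squeeze(-1)
        return out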
@@ -77,24 +77,41 @@ def test_matmul():
     opr_test(cases, F.matmul, ref_fn=np.matmul)

     batch_size = 10
-    shape1 = (batch_size, 2, 3)
-    shape2 = (batch_size, 3, 4)
-    shape3 = (batch_size, 10, 4, 5)
+    shape1 = (2,)
+    shape2 = (batch_size, 2, 3)
+    shape3 = (batch_size, 3, 4)
+    shape4 = (batch_size, 10, 4, 2)
+    shape5 = (batch_size, 10, 2, 4)
     data1 = np.random.random(shape1).astype("float32")
     data2 = np.random.random(shape2).astype("float32")
     data3 = np.random.random(shape3).astype("float32")
+    data4 = np.random.random(shape4).astype("float32")
+    data5 = np.random.random(shape5).astype("float32")

-    cases = [{"input": [data1, data2]}, {"input": [data2, data3]}]
-    for i in range(0, batch_size):
-
-        def compare_fn(x, y):
-            x.numpy()[i, ...] == y
-
-        opr_test(
-            cases,
-            F.matmul,
-            compare_fn=compare_fn,
-            ref_fn=lambda x, y: np.matmul(x[i, ...], y[i, ...]),
-        )
+    cases = [
+        {"input": [data1, data2]},
+        {"input": [data2, data3]},
+        {"input": [data3, data4]},
+        {"input": [data4, data5]},
+    ]
+    for _ in range(0, batch_size):
+        opr_test(
+            cases, F.matmul, ref_fn=np.matmul,
+        )
+
+    opr_test(
+        [{"input": [data1, data4]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
+        transpose_b=True,
+    )
+
+    opr_test(
+        [{"input": [data3, data2]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
+        transpose_a=True,
+        transpose_b=True,
+    )
...
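The two added `opr_test` calls pin down the transpose semantics: `transpose_a` and `transpose_b` swap only the trailing two axes, leaving batch dimensions untouched, which is why the reference functions use `transpose(0, 1, 3, 2)` and `transpose(0, 2, 1)`. A rank-agnostic way to write the same reference (a sketch; `ref_transposed_matmul` is a hypothetical helper, not part of the test suite):

    import numpy as np

    def ref_transposed_matmul(a, b, transpose_a=False, transpose_b=False):
        # only the last two axes are swapped; batch axes stay in place
        if transpose_a:
            a = np.swapaxes(a, -1, -2)
        if transpose_b:
            b = np.swapaxes(b, -1, -2)
        return np.matmul(a, b)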