Created by: mkliegl
This closes issues #4683 (closed) and #4696 (closed) .
The MatMul operator is used to perform (batched) matrix multiplication
over the last two dimensions of the input tensors X
and Y
.
If a transpose flag is specified, the last two dimensions of the
tensor are transposed. If the tensor is rank-1 of shape [D], then
for X
it is treated as [1, D] in nontransposed form and as [D, 1]
in transposed form, whereas for Y
it is the opposite: It is treated
as [D, 1] in nontransposed form and as [1, D] in transposed form.
Examples without transpose:
- X: [K], Y: [K] => Out: [1]
- X: [K], Y: [K, N] => Out: [N]
- X: [B, M, K], Y: [K] => Out: [B, M]
- X: [M, K], Y: [B, K, N] => Out: [B, M, N]
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
The behavior is designed to be similar to the numpy.matmul
function.
The differences are:
- Currently only rank 1 to rank 3 input tensors are supported.
- We add
transpose_X
andtranspose_Y
flags, similar to BLAS routines.
If there is interest, I could add support for rank 4 and higher tensors in a future PR. Essentially this should just involve adding some code to reshape to rank 3 and then undoing the reshape.