Commit 220eef60 authored by Bob Zhu, committed by Tao Luo

Extend Matmul to support matrix multiplication with multiple heads (#18570)

* extend matmul op to support multiple head multiplication

With multiple-head support, the multiplication of two big matrices is split
into multiplications of several (head_number) small matrices. e.g. if Mat A is
[3, 24] and Mat B is [24, 4], then multiplying A and B with head_number 4
splits Mat A into 4 matrices of [3, 6] and Mat B into 4 matrices of [6, 4].
The final result consists of 4 matrices of [3, 4], concatenated into [3, 16].
Parent 075e1cf7
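As a quick illustration of the splitting scheme described in the commit message, here is a minimal NumPy sketch (not part of the commit; the helper name matmul_with_heads is made up for this example):

import numpy as np

def matmul_with_heads(A, B, head_number):
    # Illustrative sketch only: split A by columns and B by rows into
    # head_number slices, multiply slice by slice, and concatenate the results.
    sub_k = A.shape[1] // head_number          # width of each A slice / height of each B slice
    outs = []
    for i in range(head_number):
        a_i = A[:, i * sub_k:(i + 1) * sub_k]  # e.g. a [3, 6] slice of A
        b_i = B[i * sub_k:(i + 1) * sub_k, :]  # e.g. a [6, 4] slice of B
        outs.append(a_i @ b_i)                 # e.g. a [3, 4] partial result
    return np.concatenate(outs, axis=1)        # e.g. [3, 16]

A = np.random.rand(3, 24)
B = np.random.rand(24, 4)
print(matmul_with_heads(A, B, 4).shape)        # prints (3, 16)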
......@@ -112,6 +112,15 @@ class Blas {
template <typename T>
void GEMM_FREE(T* data) const;
#if !defined(PADDLE_WITH_CUDA)
template <typename T>
void MatMulWithHead(const framework::Tensor& mat_a,
const MatDescriptor& dim_a,
const framework::Tensor& mat_b,
const MatDescriptor& dim_b, T alpha, int head_number,
framework::Tensor* mat_out, T beta) const;
#endif
#endif
template <typename T>
......@@ -176,6 +185,14 @@ class Blas {
int K, T alpha, const T* A, const T* B, T beta, T* C,
int batchCount, int64_t strideA, int64_t strideB) const;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
template <typename T>
void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
int M, int N, int K, T alpha, const T* A, const T* B,
T beta, T* C, int batchCount, int64_t strideA,
int64_t strideB, int64_t head_number) const;
#endif
template <typename T>
void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
const framework::Tensor& mat_b, const MatDescriptor& dim_b,
......@@ -221,6 +238,13 @@ class BlasT : private Blas<DeviceContext> {
void GEMM_FREE(ARGS... args) const {
Base()->template GEMM_FREE<T>(args...);
}
#if !defined(PADDLE_WITH_CUDA)
template <typename... ARGS>
void MatMulWithHead(ARGS... args) const {
Base()->template MatMulWithHead<T>(args...);
}
#endif
#endif
template <typename... ARGS>
......
......@@ -567,6 +567,41 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
int64_t strideA, int64_t strideB, int64_t head_number) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N * head_number;
int sub_width = K / head_number;
auto a_array = std::vector<const T *>(batchCount);
auto b_array = std::vector<const T *>(batchCount);
auto c_array = std::vector<T *>(batchCount);
for (int i = 0; i < head_number; i++) {
int sub_matA_offset = (transA == CblasNoTrans) ? i * (K / head_number)
: i * (K / head_number) * M;
int sub_matB_offset = (transB == CblasNoTrans) ? i * (K / head_number) * N
: i * (K / head_number);
int sub_matC_offset = i * N;
for (int k = 0; k < batchCount; ++k) {
a_array[k] = &A[k * strideA] + sub_matA_offset;
b_array[k] = &B[k * strideB] + sub_matB_offset;
c_array[k] = &C[k * M * head_number * N] + sub_matC_offset;
}
CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &sub_width,
&alpha, a_array.data(), &lda, b_array.data(), &ldb,
&beta, c_array.data(), &ldc, 1 /* group_count */,
&batchCount);
}
}
#endif
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const int M, const int N, const int K,
......@@ -627,6 +662,67 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
dim_a.stride_, dim_b.stride_);
}
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
/*
 * Multiply two matrices with multiple heads
 *
 * A new parameter, i.e. head_number, is added compared to the normal MatMul.
 * head_number describes the number of heads into which a matrix is vertically
 * split.
 *
 * When the user calls this API, the multiplication of two big matrices is
 * split into multiplications of several (head_number) small matrices. e.g. if
 * Mat A is [3, 24] and Mat B is [24, 4], then multiplying A and B with
 * head_number 4 splits Mat A into 4 matrices of [3, 6] and Mat B into 4
 * matrices of [6, 4]. The final result consists of 4 matrices of [3, 4],
 * concatenated into [3, 16].
*
*/
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMulWithHead(
const framework::Tensor &mat_a, const MatDescriptor &dim_a,
const framework::Tensor &mat_b, const MatDescriptor &dim_b, T alpha,
int head_number, framework::Tensor *mat_out, T beta) const {
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
PADDLE_ENFORCE_EQ(dim_a.width_ % head_number, 0);
PADDLE_ENFORCE_GE(head_number, 1);
PADDLE_ENFORCE_LE(head_number, dim_a.width_);
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
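    // Non-batched case: one GEMM per head. Each head multiplies a
    // height_ x (width_ / head_number) sub-block of mat_a by a
    // (height_ / head_number) x width_ sub-block of mat_b; the per-head
    // results are written side by side into mat_out
    // (ldc = head_number * dim_b.width_).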
for (int i = 0; i < head_number; i++) {
int sub_matA_offset =
dim_a.trans_ ? i * (dim_a.width_ / head_number) * dim_a.height_
: i * (dim_a.width_ / head_number);
int sub_matB_offset =
dim_b.trans_ ? i * (dim_b.height_ / head_number)
: i * (dim_b.height_ / head_number) * dim_b.width_;
int sub_matC_offset = i * dim_b.width_;
int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_;
int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_;
int ldc = head_number * dim_b.width_;
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
dim_a.width_ / head_number, alpha,
mat_a.data<T>() + sub_matA_offset, lda,
mat_b.data<T>() + sub_matB_offset, ldb, beta,
mat_out->data<T>() + sub_matC_offset, ldc);
}
} else {
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
this->template BatchedGEMMWithHead<T>(
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
dim_a.stride_, dim_b.stride_, head_number);
}
}
#endif
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::VINV(int n, const T *a, T *y) const {
......
......@@ -60,7 +60,18 @@ class MatMulKernel : public framework::OpKernel<T> {
auto mat_dim_b = math::CreateMatrixDescriptor(
ColumnMatrixFromVector(y.dims()), 0, context.Attr<bool>("transpose_Y"));
auto scale = static_cast<T>(context.Attr<float>("alpha"));
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
int head_number = context.Attr<int>("head_number");
if (1 == head_number) {
blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0));
} else {
blas.MatMulWithHead(x, mat_dim_a, y, mat_dim_b, scale, head_number, out,
T(0));
}
#else
blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0));
#endif
}
};
......@@ -295,16 +306,25 @@ class MatMulOp : public framework::OperatorWithKernel {
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
}
std::vector<int64_t> dim_out;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
int head_number = context->Attrs().Get<int>("head_number");
PADDLE_ENFORCE_GE(head_number, 1);
PADDLE_ENFORCE_LE(head_number, mat_dim_x.width_);
int64_t dim_out_y = head_number * mat_dim_y.width_;
#else
int64_t dim_out_y = mat_dim_y.width_;
#endif
if (mat_dim_x.batch_size_ != 0) {
dim_out = framework::vectorize(dim_x);
dim_out[dim_out.size() - 2] = mat_dim_x.height_;
dim_out[dim_out.size() - 1] = mat_dim_y.width_;
dim_out[dim_out.size() - 1] = dim_out_y;
} else if (mat_dim_y.batch_size_ != 0) {
dim_out = framework::vectorize(dim_y);
dim_out[dim_out.size() - 2] = mat_dim_x.height_;
dim_out[dim_out.size() - 1] = mat_dim_y.width_;
dim_out[dim_out.size() - 1] = dim_out_y;
} else {
dim_out = {mat_dim_x.height_, mat_dim_y.width_};
dim_out = {mat_dim_x.height_, dim_out_y};
}
if (dim_x.size() == 1 && dim_out[dim_out.size() - 2] == 1) {
......@@ -339,6 +359,10 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
)DOC")
.SetDefault(false);
AddAttr<float>("alpha", "The scale of Out").SetDefault(1.0f);
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
AddAttr<int>("head_number", "The number of heads of the matrix")
.SetDefault(1);
#endif
AddComment(R"DOC(
MatMul Operator.
......@@ -360,6 +384,9 @@ Examples without transpose:
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
- X: [B, ..., M, K], Y: [B, ..., K, N] => Out: [B, ..., M, N]
Example of matrix multiplication with a head_number of H:
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, H * N]
The behavior is designed to be similar to the `numpy.matmul` function.
The differences are:
- When the rank of the input data is less than or equal to 3, it
......@@ -367,6 +394,9 @@ The differences are:
- When the rank of the input is greater than 3, the rank of X and
Y must be equal, and the first `rank - 2` dimensions must be equal.
- We add `transpose_X` and `transpose_Y` flags.
- We add a `head_number` attribute, which multiplies the two matrices head by
  head and concatenates the outputs of the several (head_number) small matrix
  multiplications.
Both the input `X` and `Y` can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input `X`.
......
......@@ -31,7 +31,6 @@ if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter)
endif()
LIST(REMOVE_ITEM TEST_OPS test_launch)
if (NOT ${WITH_GPU})
......@@ -72,6 +71,11 @@ if(NOT WITH_MKLML)
list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op)
endif()
if(WITH_GPU OR NOT WITH_MKLML)
# matmul with multiple heads needs MKL support
LIST(REMOVE_ITEM TEST_OPS test_matmul_op_with_head)
endif()
function(py_test_modules TARGET_NAME)
if(WITH_TESTING)
set(options SERIAL)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
def generate_compatible_shapes_mul_head(dim_X, dim_Y, transpose_X, transpose_Y):
BATCH_SIZE = 2
M = 3
N = 4
K = 24
if (dim_X == 1 and transpose_X) or (dim_Y == 1 and transpose_Y):
K = 1
if dim_X == 1:
if transpose_X:
shape_X = [M]
else:
shape_X = [K]
if dim_Y == 1:
if transpose_Y:
shape_Y = [N]
else:
shape_Y = [K]
if dim_X >= 2:
if transpose_X:
shape_X = [K, M]
else:
shape_X = [M, K]
if dim_X == 3:
shape_X = [BATCH_SIZE] + shape_X
if dim_Y >= 2:
if transpose_Y:
shape_Y = [N, K]
else:
shape_Y = [K, N]
if dim_Y == 3:
shape_Y = [BATCH_SIZE] + shape_Y
return shape_X, shape_Y
def matmul_head(X, Y, head_number=1):
x = []
y = []
z = []
sub_x_width = X.shape[-1] // head_number
sub_y_height = Y.shape[-2] // head_number
if np.ndim(X) == 2:
for i in range(0, head_number):
x.append(X[:, i * sub_x_width:i * sub_x_width + sub_x_width])
y.append(Y[i * sub_y_height:i * sub_y_height + sub_y_height, :])
for i in range(0, head_number):
z.append(np.matmul(x[i], y[i]))
Z = np.concatenate((z), axis=1)
elif np.ndim(X) == 3:
for i in range(0, head_number):
x.append(X[:, :, i * sub_x_width:i * sub_x_width + sub_x_width])
y.append(Y[:, i * sub_y_height:i * sub_y_height + sub_y_height, :])
for i in range(0, head_number):
z.append(np.matmul(x[i], y[i]))
Z = np.concatenate((z), axis=2)
    else:
        raise ValueError("matmul_head only supports 2-D and 3-D inputs")
    return Z
def transpose_mat(X):
if X.ndim >= 2:
dim = np.arange(X.ndim)
dim[[-1, -2]] = dim[[-2, -1]]
X = np.transpose(X, tuple(dim))
return X
def reference_matmul_mul_head(X,
Y,
head_number=1,
transpose_X=False,
transpose_Y=False):
"""Reference forward implementation using np.matmul."""
# np.matmul does not support the transpose flags, so we manually
# transpose X and Y appropriately.
if transpose_X:
X = transpose_mat(X)
if transpose_Y:
Y = transpose_mat(Y)
Out = matmul_head(X, Y, head_number)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float32")
return Out
# Generator for multiple head
class GeneratorMulHead(object):
def setUp(self):
self.op_type = "matmul"
X = np.random.random(self.shape_X).astype("float32")
Y = np.random.random(self.shape_Y).astype("float32")
        Out = reference_matmul_mul_head(X, Y, self.head_number,
                                        self.transpose_X, self.transpose_Y)
self.inputs = {'X': X, 'Y': Y}
self.attrs = {
'transpose_X': self.transpose_X,
'transpose_Y': self.transpose_Y,
'head_number': self.head_number
}
self.outputs = {'Out': Out}
def test_check_output(self):
self.check_output(atol=1e-3)
def inject_test_multiple_head(dim_x, dim_y, trans_x, trans_y, head_number):
test_name = (
'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}_head_{}'.format(
dim_x, dim_y, trans_x, trans_y, head_number))
shape_x, shape_y = generate_compatible_shapes_mul_head(dim_x, dim_y,
trans_x, trans_y)
globals()[test_name] = type(test_name, (GeneratorMulHead, OpTest), {
'shape_X': shape_x,
'shape_Y': shape_y,
'transpose_X': trans_x,
'transpose_Y': trans_y,
'head_number': head_number
})
# test cases for multiple heads
for dim in (2, 3):
    for transpose_x in (False, True):
        for transpose_y in (False, True):
            inject_test_multiple_head(dim, dim, transpose_x, transpose_y, 4)
if __name__ == "__main__":
unittest.main()