Commit c670058a, authored by Bob Zhu, committed by Tao Luo

add support of matmul with multiple head even different width and height (#19708)

* add support of matmul with multiple head even different width and height

The original multi-head matmul supports only the case mat_a.width == mat_b.height; in that case, mat_b is split horizontally. This patch extends support to the case where mat_a.width != mat_b.height but mat_a.width / head_number == mat_b.height; in that case, mat_b is split vertically.

For example, if A is [3, 8], B is [2, 16], and head_number is 4, A is split into 4 matrices of [3, 2] and B is (vertically) split into 4 matrices of [2, 4]. The final result is 4 matrices of [3, 4], i.e. [3, 16].

test=develop

* add support of matmul with multiple head even different width and height

The original multi-head matmul supports only the case mat_a.width == mat_b.height; in that case, mat_b is split horizontally. This patch extends support to the case where mat_a.width != mat_b.height but mat_a.width / head_number == mat_b.height; in that case, mat_b is split vertically.

For example, if A is [3, 8], B is [2, 16], and head_number is 4, A is split into 4 matrices of [3, 2] and B is (vertically) split into 4 matrices of [2, 4]. The final result is 4 matrices of [3, 4], i.e. [3, 16].

test=develop

* refactor the code of matmul with multiple head even different width and height

test=develop
Parent 6884dc80
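The new behavior can be pictured with a minimal NumPy sketch (an illustration of the intended semantics, not code from this patch; the helper name is made up, and it assumes head_number evenly divides both A's width and B's width):

import numpy as np

def matmul_multi_head_vertical(A, B, head_number):
    # A: [M, K1], B: [H2, W2] with K1 / head_number == H2.
    # Each head multiplies a column slice of A by a column slice of B;
    # the per-head results are concatenated along the last axis.
    sub_a = A.shape[1] // head_number
    sub_b = B.shape[1] // head_number
    assert sub_a == B.shape[0]
    outs = [A[:, i * sub_a:(i + 1) * sub_a] @ B[:, i * sub_b:(i + 1) * sub_b]
            for i in range(head_number)]
    return np.concatenate(outs, axis=1)

A = np.arange(24, dtype=np.float32).reshape(3, 8)   # [3, 8]
B = np.arange(32, dtype=np.float32).reshape(2, 16)  # [2, 16]
print(matmul_multi_head_vertical(A, B, 4).shape)    # (3, 16)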
@@ -125,7 +125,8 @@ class Blas {
                      const MatDescriptor& dim_a,
                      const framework::Tensor& mat_b,
                      const MatDescriptor& dim_b, T alpha, int head_number,
-                     framework::Tensor* mat_out, T beta) const;
+                     framework::Tensor* mat_out, T beta,
+                     bool mat_y_split_vertical) const;
 #endif
 #endif
@@ -194,9 +195,10 @@ class Blas {
 #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
   template <typename T>
   void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
-                           int M, int N, int K, T alpha, const T* A, const T* B,
-                           T beta, T* C, int batchCount, int64_t strideA,
-                           int64_t strideB, int64_t head_number) const;
+                           int W1, int H1, int W2, int H2, T alpha, const T* A,
+                           const T* B, T beta, T* C, int batchCount,
+                           int64_t strideA, int64_t strideB,
+                           int64_t head_number, bool split_b_vertical) const;
 #endif
   template <typename T>
...
@@ -583,33 +583,64 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
-    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
-    T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
-    int64_t strideA, int64_t strideB, int64_t head_number) const {
-  int lda = (transA == CblasNoTrans) ? K : M;
-  int ldb = (transB == CblasNoTrans) ? N : K;
-  int ldc = N * head_number;
-  int sub_width = K / head_number;
+    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2,
+    int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
+    int64_t strideA, int64_t strideB, int64_t head_number,
+    bool split_b_vertical) const {
+  int lda = (transA == CblasNoTrans) ? W1 : H1;
+  int ldb = (transB == CblasNoTrans) ? W2 : H2;
   auto a_array = std::vector<const T *>(batchCount);
   auto b_array = std::vector<const T *>(batchCount);
   auto c_array = std::vector<T *>(batchCount);
-  for (int i = 0; i < head_number; i++) {
-    int sub_matA_offset = (transA == CblasNoTrans) ? i * (K / head_number)
-                                                   : i * (K / head_number) * M;
-    int sub_matB_offset = (transB == CblasNoTrans) ? i * (K / head_number) * N
-                                                   : i * (K / head_number);
-    int sub_matC_offset = i * N;
-    for (int k = 0; k < batchCount; ++k) {
-      a_array[k] = &A[k * strideA] + sub_matA_offset;
-      b_array[k] = &B[k * strideB] + sub_matB_offset;
-      c_array[k] = &C[k * M * head_number * N] + sub_matC_offset;
+  if (split_b_vertical) {
+    int ldc = W2;
+    int sub_width = W2 / head_number;
+
+    for (int i = 0; i < head_number; i++) {
+      int sub_matA_offset = (transA == CblasNoTrans)
+                                ? i * (W1 / head_number)
+                                : i * (W1 / head_number) * H1;
+      int sub_matB_offset = (transB == CblasNoTrans)
+                                ? i * (W2 / head_number)
+                                : i * (W2 / head_number) * H2;
+      int sub_matC_offset = i * W2 / head_number;
+      for (int k = 0; k < batchCount; ++k) {
+        a_array[k] = &A[k * strideA] + sub_matA_offset;
+        b_array[k] = &B[k * strideB] + sub_matB_offset;
+        c_array[k] = &C[k * H1 * W2] + sub_matC_offset;
+      }
+      CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width,
+                           &H2, &alpha, a_array.data(), &lda, b_array.data(),
+                           &ldb, &beta, c_array.data(), &ldc,
+                           1 /* group_count */, &batchCount);
     }
-    CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &sub_width,
-                         &alpha, a_array.data(), &lda, b_array.data(), &ldb,
-                         &beta, c_array.data(), &ldc, 1 /* group_count */,
-                         &batchCount);
+  } else {
+    PADDLE_ENFORCE_EQ(W1, H2);
+    int ldc = W2 * head_number;
+    int sub_width = W1 / head_number;
+
+    for (int i = 0; i < head_number; i++) {
+      int sub_matA_offset = (transA == CblasNoTrans)
+                                ? i * (W1 / head_number)
+                                : i * (W1 / head_number) * H1;
+      int sub_matB_offset = (transB == CblasNoTrans)
+                                ? i * (W1 / head_number) * W2
+                                : i * (W1 / head_number);
+      int sub_matC_offset = i * W2;
+      for (int k = 0; k < batchCount; ++k) {
+        a_array[k] = &A[k * strideA] + sub_matA_offset;
+        b_array[k] = &B[k * strideB] + sub_matB_offset;
+        c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset;
+      }
+      CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2,
+                           &sub_width, &alpha, a_array.data(), &lda,
+                           b_array.data(), &ldb, &beta, c_array.data(), &ldc,
+                           1 /* group_count */, &batchCount);
+    }
   }
 }
 #endif
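As a reference for the two branches above, here is a NumPy sketch of the per-head slicing they perform on one batch (illustrative only; the real code drives MKL's GEMM_BATCH with raw pointers, strides, and leading dimensions rather than array slices):

import numpy as np

def batched_gemm_with_head(A, B, head_number, split_b_vertical):
    # A: [batch, H1, W1], B: [batch, H2, W2]
    if split_b_vertical:
        # W1 / head_number == H2: slice both A and B by columns.
        sa, sb = A.shape[-1] // head_number, B.shape[-1] // head_number
        heads = [A[:, :, i * sa:(i + 1) * sa] @ B[:, :, i * sb:(i + 1) * sb]
                 for i in range(head_number)]
    else:
        # W1 == H2: slice A by columns and B by rows.
        s = A.shape[-1] // head_number
        heads = [A[:, :, i * s:(i + 1) * s] @ B[:, i * s:(i + 1) * s, :]
                 for i in range(head_number)]
    # Per-head results are written side by side along the last axis.
    return np.concatenate(heads, axis=-1)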
@@ -690,51 +721,86 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
 * When user calls this API, the multiplication of two big matrixes is split
 * into multiplication of several (head_number_) small matrixes. e.g. if Mat A
 * is [3, 24] and Mat B is [24, 4], when multiple A and B with head_number as
-* 4, Mat A will be split as 4 matrix of [3, 6] and Mat B will be 4 matrix of
-* [6, 4]. The result of final matrix will be 4 matrix of [3, 4], i.e. [3, 16].
-*
+* 4, Mat A will be splitted as 4 matrix of [3, 6] and Mat B will be
+* (horizontally) splitted as 4 matrix of [6, 4]. The result of final matrix
+* will be 4 matrix of [3, 4], i.e. [3, 16].
+*
+* Another example is A is [3, 8], B is [2, 16], head_number is 4. In this
+* case, A will be splitted as [3, 2], B will be (vertically) splitted as
+* [2, 4]. The final result will be 4 matrix of 4 matrix of [3,4], i.e. [3, 16]
 */
 template <typename DeviceContext>
 template <typename T>
-void Blas<DeviceContext>::MatMulWithHead(
-    const framework::Tensor &mat_a, const MatDescriptor &dim_a,
-    const framework::Tensor &mat_b, const MatDescriptor &dim_b, T alpha,
-    int head_number, framework::Tensor *mat_out, T beta) const {
-  PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
+void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
+                                         const MatDescriptor &dim_a,
+                                         const framework::Tensor &mat_b,
+                                         const MatDescriptor &dim_b, T alpha,
+                                         int head_number,
+                                         framework::Tensor *mat_out, T beta,
+                                         bool mat_b_split_vertical) const {
   PADDLE_ENFORCE_EQ(dim_a.width_ % head_number, 0);
   PADDLE_ENFORCE_GE(head_number, 1);
   PADDLE_ENFORCE_LE(head_number, dim_a.width_);
   CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
+  if (mat_b_split_vertical) {
+    PADDLE_ENFORCE_EQ(dim_b.height_, dim_a.width_ / head_number);
+    PADDLE_ENFORCE_EQ(dim_b.width_ % head_number, 0);
+  }
+
   if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
+    int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_;
+    int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_;
+
+    int sub_matA_offset;
+    int sub_matB_offset;
+    int sub_matC_offset;
+    int sub_mat_M = dim_a.height_;
+    int sub_mat_N;
+    int sub_mat_K;
+    int ldc;
+
     for (int i = 0; i < head_number; i++) {
-      int sub_matA_offset =
-          dim_a.trans_ ? i * (dim_a.width_ / head_number) * dim_a.height_
-                       : i * (dim_a.width_ / head_number);
-      int sub_matB_offset =
-          dim_b.trans_ ? i * (dim_b.height_ / head_number)
-                       : i * (dim_b.height_ / head_number) * dim_b.width_;
-      int sub_matC_offset = i * dim_b.width_;
-      int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_;
-      int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_;
-      int ldc = head_number * dim_b.width_;
-
-      this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
-                             dim_a.width_ / head_number, alpha,
-                             mat_a.data<T>() + sub_matA_offset, lda,
+      sub_matA_offset = dim_a.trans_
+                            ? i * (dim_a.width_ / head_number) * dim_a.height_
+                            : i * (dim_a.width_ / head_number);
+      if (mat_b_split_vertical) {
+        sub_matB_offset = dim_b.trans_
+                              ? i * (dim_b.width_ / head_number) * dim_b.height_
+                              : i * (dim_b.width_ / head_number);
+        sub_matC_offset = i * dim_b.width_ / head_number;
+
+        sub_mat_N = dim_b.width_ / head_number;
+        sub_mat_K = dim_b.height_;
+
+        ldc = dim_b.width_;
+      } else {
+        sub_matB_offset =
+            dim_b.trans_ ? i * (dim_b.height_ / head_number)
+                         : i * (dim_b.height_ / head_number) * dim_b.width_;
+        sub_matC_offset = i * dim_b.width_;
+
+        sub_mat_N = dim_b.width_;
+        sub_mat_K = dim_a.width_ / head_number;
+
+        ldc = head_number * dim_b.width_;
+      }
+      this->template GEMM<T>(transA, transB, sub_mat_M, sub_mat_N, sub_mat_K,
+                             alpha, mat_a.data<T>() + sub_matA_offset, lda,
                              mat_b.data<T>() + sub_matB_offset, ldb, beta,
                              mat_out->data<T>() + sub_matC_offset, ldc);
     }
   } else {
-    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
-                   dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
+    PADDLE_ENFORCE_EQ((dim_a.batch_size_ == dim_b.batch_size_ ||
+                       dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0),
+                      true);
     this->template BatchedGEMMWithHead<T>(
-        transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
-        mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
+        transA, transB, dim_a.width_, dim_a.height_, dim_b.width_,
+        dim_b.height_, alpha, mat_a.data<T>(), mat_b.data<T>(), beta,
+        mat_out->data<T>(),
        dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
-        dim_a.stride_, dim_b.stride_, head_number);
+        dim_a.stride_, dim_b.stride_, head_number, mat_b_split_vertical);
   }
 }
 #endif
...
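The shape constraints added above can be restated as a small Python check (a sketch of the rule only, not part of the patch; note that in the horizontal-split path the W1 == H2 check lives in BatchedGEMMWithHead):

def check_matmul_with_head_shapes(h_a, w_a, h_b, w_b, head_number,
                                  mat_b_split_vertical):
    # Mirrors the PADDLE_ENFORCE checks in MatMulWithHead (illustrative).
    assert 1 <= head_number <= w_a
    assert w_a % head_number == 0
    if mat_b_split_vertical:
        # B is split by columns: each head sees a [h_b, w_b / head_number] block.
        assert h_b == w_a // head_number
        assert w_b % head_number == 0
    else:
        # B is split by rows: the classic multi-head case.
        assert w_a == h_b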
@@ -63,11 +63,13 @@ class MatMulKernel : public framework::OpKernel<T> {
 #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
     int head_number = context.Attr<int>("head_number");
-    if (1 == head_number) {
-      blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0));
-    } else {
+    bool split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_);
+
+    if (head_number > 1) {
       blas.MatMulWithHead(x, mat_dim_a, y, mat_dim_b, scale, head_number, out,
-                          T(0));
+                          T(0), split_vertical_y);
+    } else {
+      blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0));
     }
 #else
     blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0));
@@ -300,19 +302,22 @@ class MatMulOp : public framework::OperatorWithKernel {
         math::CreateMatrixDescriptor(ColumnMatrixFromVector(dim_y), 0,
                                      context->Attrs().Get<bool>("transpose_Y"));
-    PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
     if (context->IsRuntime()) {
       PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
                      mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
     }
     std::vector<int64_t> dim_out;
+    int64_t dim_out_y = mat_dim_y.width_;
 #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
     int head_number = context->Attrs().Get<int>("head_number");
-    PADDLE_ENFORCE_GE(head_number, 1);
+    bool split_vertical_y = (mat_dim_x.width_ != mat_dim_y.height_);
     PADDLE_ENFORCE_LE(head_number, mat_dim_x.width_);
-    int64_t dim_out_y = head_number * mat_dim_y.width_;
+
+    if (!split_vertical_y && head_number > 0) {
+      dim_out_y = head_number * mat_dim_y.width_;
+    }
 #else
-    int64_t dim_out_y = mat_dim_y.width_;
+    PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
 #endif
     if (mat_dim_x.batch_size_ != 0) {
...
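The dim_out_y logic above boils down to the following sketch (an illustrative restatement, assuming the MKLML build path):

def matmul_out_width(w_x, h_y, w_y, head_number, with_mklml=True):
    # With MKLML and a horizontal split (w_x == h_y), the output width grows
    # by head_number; with a vertical split it stays w_y.
    dim_out_y = w_y
    if with_mklml:
        split_vertical_y = (w_x != h_y)
        if not split_vertical_y and head_number > 0:
            dim_out_y = head_number * w_y
    return dim_out_y

# [3, 24] x [24, 4], head_number 4 -> width 16; [3, 8] x [2, 16], head_number 4 -> width 16.
print(matmul_out_width(24, 24, 4, 4), matmul_out_width(8, 2, 16, 4))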
@@ -148,11 +148,147 @@ def inject_test_multiple_head(dim_x, dim_y, trans_x, trans_y, head_number):
     })
 
 
+def matmul_head2(X, Y, head_number=1):
+    x = []
+    y = []
+    z = []
+    sub_x_width = X.shape[-1] // head_number
+    sub_y_width = Y.shape[-1] // head_number
+    assert (sub_x_width == Y.shape[-2]
+            ), "Error: incompatible head number or matrix size!"
+    if np.ndim(X) == 2:
+        for i in range(0, head_number):
+            x.append(X[:, i * sub_x_width:i * sub_x_width + sub_x_width])
+            y.append(Y[:, i * sub_y_width:i * sub_y_width + sub_y_width])
+        for i in range(0, head_number):
+            z.append(np.matmul(x[i], y[i]))
+        Z = np.concatenate((z), axis=1)
+    elif np.ndim(X) == 3:
+        for i in range(0, head_number):
+            x.append(X[:, :, i * sub_x_width:i * sub_x_width + sub_x_width])
+            y.append(Y[:, :, i * sub_y_width:i * sub_y_width + sub_y_width])
+        for i in range(0, head_number):
+            z.append(np.matmul(x[i], y[i]))
+        Z = np.concatenate((z), axis=2)
+    else:
+        assert False, "ERROR: Not supported dimension!"
+    return Z
+
+
+def reference_matmul_mul_head2(X,
+                               Y,
+                               head_number=1,
+                               transpose_X=False,
+                               transpose_Y=False):
+    """Reference forward implementation using np.matmul."""
+    # np.matmul does not support the transpose flags, so we manually
+    # transpose X and Y appropriately.
+    if transpose_X:
+        X = transpose_mat(X)
+    if transpose_Y:
+        Y = transpose_mat(Y)
+    Out = matmul_head2(X, Y, head_number)
+    if not Out.shape:
+        # We do not support 0-dimensional Tensors (scalars). So where
+        # np.matmul outputs a scalar, we must convert to a Tensor of
+        # shape (1, ) instead.
+        # Everywhere else, we are compatible with np.matmul.
+        Out = np.array([Out], dtype="float32")
+    return Out
+
+
+def generate_compatible_shapes_mul_head2(dim_X, dim_Y, transpose_X,
+                                         transpose_Y):
+    BATCH_SIZE = 2
+    # Assume head number H is 4. We need make sure K1/H = M2
+    M1 = 3
+    K1 = 8
+    M2 = 2
+    K2 = 16
+
+    if dim_X >= 2:
+        if transpose_X:
+            shape_X = [K1, M1]
+        else:
+            shape_X = [M1, K1]
+    if dim_X == 3:
+        shape_X = [BATCH_SIZE] + shape_X
+    if dim_Y >= 2:
+        if transpose_Y:
+            shape_Y = [K2, M2]
+        else:
+            shape_Y = [M2, K2]
+    if dim_Y == 3:
+        shape_Y = [BATCH_SIZE] + shape_Y
+    return shape_X, shape_Y
+
+
+# Generator for multiple head, case 2 when width of X is not same as height of Y
+class GeneratorMulHead2(object):
+    def setUp(self):
+        self.op_type = "matmul"
+        X = np.zeros(self.shape_X)
+        Y = np.zeros(self.shape_Y)
+        if len(self.shape_X) == 2:
+            X = np.arange(
+                0, self.shape_X[-1] * self.shape_X[-2],
+                dtype=np.float32).reshape(self.shape_X)
+            Y = np.arange(
+                0, self.shape_Y[-1] * self.shape_Y[-2],
+                dtype=np.float32).reshape(self.shape_Y)
+        else:
+            for i in range(0, len(self.shape_X) - 1):
+                X[i, :, :] = np.arange(
+                    0, self.shape_X[-1] * self.shape_X[-2],
+                    dtype=np.float32).reshape(list(self.shape_X)[-2:])
+                Y[i, :, :] = np.arange(
+                    0, self.shape_Y[-1] * self.shape_Y[-2],
+                    dtype=np.float32).reshape(list(self.shape_Y)[-2:])
+
+        Out = reference_matmul_mul_head2(X, Y, 4, self.transpose_X,
+                                         self.transpose_Y)
+        self.inputs = {'X': X, 'Y': Y}
+        self.attrs = {
+            'transpose_X': self.transpose_X,
+            'transpose_Y': self.transpose_Y,
+            'head_number': self.head_number
+        }
+        self.outputs = {'Out': Out}
+
+    def test_check_output(self):
+        self.check_output(atol=1e-3)
+
+
+def inject_test_multiple_head2(dim_x, dim_y, trans_x, trans_y, head_number):
+    test_name = (
+        'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}_head2_{}'.format(
+            dim_x, dim_y, trans_x, trans_y, head_number))
+    shape_x, shape_y = generate_compatible_shapes_mul_head2(dim_x, dim_y,
+                                                            trans_x, trans_y)
+    globals()[test_name] = type(test_name, (GeneratorMulHead2, OpTest), {
+        'shape_X': shape_x,
+        'shape_Y': shape_y,
+        'transpose_X': trans_x,
+        'transpose_Y': trans_y,
+        'head_number': head_number
+    })
+
+
 #test case for multiple head
 for dim in (2, 3):
     for transose_x in (False, True):
         for transose_y in (False, True):
             inject_test_multiple_head(dim, dim, transose_x, transose_y, 4)
 
+#test case for multiple head when X.width != Y.height
+for dim in (2, 3):
+    for transose_x in (False, True):
+        for transose_y in (False, True):
+            inject_test_multiple_head2(dim, dim, transose_x, transose_y, 4)
+
 if __name__ == "__main__":
     unittest.main()
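As a usage note, the 2-D, non-transposed case generated above reuses the shapes from the commit-message example; with the helper defined in the test file (a quick check, assuming that module is imported):

shape_x, shape_y = generate_compatible_shapes_mul_head2(2, 2, False, False)
print(shape_x, shape_y)  # [3, 8] [2, 16]; with head_number 4 the output is [3, 16]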