diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 0c1b61c14473f2a413ff98c1760e6b1dd77b21cf..a15dab935552c6e93fd9c0d9963985d6ea024f35 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -125,7 +125,8 @@ class Blas { const MatDescriptor& dim_a, const framework::Tensor& mat_b, const MatDescriptor& dim_b, T alpha, int head_number, - framework::Tensor* mat_out, T beta) const; + framework::Tensor* mat_out, T beta, + bool mat_y_split_vertical) const; #endif #endif @@ -194,9 +195,10 @@ class Blas { #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) template void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, - int M, int N, int K, T alpha, const T* A, const T* B, - T beta, T* C, int batchCount, int64_t strideA, - int64_t strideB, int64_t head_number) const; + int W1, int H1, int W2, int H2, T alpha, const T* A, + const T* B, T beta, T* C, int batchCount, + int64_t strideA, int64_t strideB, + int64_t head_number, bool split_b_vertical) const; #endif template diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 8bc1bd720cedb63a8568196acd021fc80a8b6671..e2620bcfd9298f38f887f8a5b35aa8efba6b7053 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -583,33 +583,64 @@ void Blas::BatchedGEMM( template <> template void Blas::BatchedGEMMWithHead( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, const T *B, T beta, T *C, int batchCount, - int64_t strideA, int64_t strideB, int64_t head_number) const { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N * head_number; - int sub_width = K / head_number; + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2, + int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount, + int64_t strideA, int64_t strideB, int64_t head_number, + bool split_b_vertical) const { + int lda = (transA == CblasNoTrans) ? W1 : H1; + int ldb = (transB == CblasNoTrans) ? W2 : H2; auto a_array = std::vector(batchCount); auto b_array = std::vector(batchCount); auto c_array = std::vector(batchCount); - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) ? i * (K / head_number) - : i * (K / head_number) * M; - int sub_matB_offset = (transB == CblasNoTrans) ? i * (K / head_number) * N - : i * (K / head_number); - int sub_matC_offset = i * N; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * M * head_number * N] + sub_matC_offset; + if (split_b_vertical) { + int ldc = W2; + int sub_width = W2 / head_number; + + for (int i = 0; i < head_number; i++) { + int sub_matA_offset = (transA == CblasNoTrans) + ? i * (W1 / head_number) + : i * (W1 / head_number) * H1; + int sub_matB_offset = (transB == CblasNoTrans) + ? i * (W2 / head_number) + : i * (W2 / head_number) * H2; + int sub_matC_offset = i * W2 / head_number; + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA] + sub_matA_offset; + b_array[k] = &B[k * strideB] + sub_matB_offset; + c_array[k] = &C[k * H1 * W2] + sub_matC_offset; + } + + CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width, + &H2, &alpha, a_array.data(), &lda, b_array.data(), + &ldb, &beta, c_array.data(), &ldc, + 1 /* group_count */, &batchCount); } - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &sub_width, - &alpha, a_array.data(), &lda, b_array.data(), &ldb, - &beta, c_array.data(), &ldc, 1 /* group_count */, - &batchCount); + } else { + PADDLE_ENFORCE_EQ(W1, H2); + int ldc = W2 * head_number; + int sub_width = W1 / head_number; + + for (int i = 0; i < head_number; i++) { + int sub_matA_offset = (transA == CblasNoTrans) + ? i * (W1 / head_number) + : i * (W1 / head_number) * H1; + int sub_matB_offset = (transB == CblasNoTrans) + ? i * (W1 / head_number) * W2 + : i * (W1 / head_number); + int sub_matC_offset = i * W2; + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA] + sub_matA_offset; + b_array[k] = &B[k * strideB] + sub_matB_offset; + c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; + } + + CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2, + &sub_width, &alpha, a_array.data(), &lda, + b_array.data(), &ldb, &beta, c_array.data(), &ldc, + 1 /* group_count */, &batchCount); + } } } #endif @@ -690,51 +721,86 @@ void Blas::MatMul(const framework::Tensor &mat_a, * When user calls this API, the multiplication of two big matrixes is split * into multiplication of several (head_number_) small matrixes. e.g. if Mat A * is [3, 24] and Mat B is [24, 4], when multiple A and B with head_number as - * 4, Mat A will be split as 4 matrix of [3, 6] and Mat B will be 4 matrix of - * [6, 4]. The result of final matrix will be 4 matrix of [3, 4], i.e. [3, 16]. - * + * 4, Mat A will be splitted as 4 matrix of [3, 6] and Mat B will be + * (horizontally) splitted as 4 matrix of [6, 4]. The result of final matrix + * will be 4 matrix of [3, 4], i.e. [3, 16]. + * Another example is A is [3, 8], B is [2, 16], head_number is 4. In this + * case, A will be splitted as [3, 2], B will be (vertically) splitted as + * [2, 4]. The final result will be 4 matrix of 4 matrix of [3,4], i.e. [3, 16] */ template template -void Blas::MatMulWithHead( - const framework::Tensor &mat_a, const MatDescriptor &dim_a, - const framework::Tensor &mat_b, const MatDescriptor &dim_b, T alpha, - int head_number, framework::Tensor *mat_out, T beta) const { - PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_); +void Blas::MatMulWithHead(const framework::Tensor &mat_a, + const MatDescriptor &dim_a, + const framework::Tensor &mat_b, + const MatDescriptor &dim_b, T alpha, + int head_number, + framework::Tensor *mat_out, T beta, + bool mat_b_split_vertical) const { PADDLE_ENFORCE_EQ(dim_a.width_ % head_number, 0); PADDLE_ENFORCE_GE(head_number, 1); PADDLE_ENFORCE_LE(head_number, dim_a.width_); CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; + if (mat_b_split_vertical) { + PADDLE_ENFORCE_EQ(dim_b.height_, dim_a.width_ / head_number); + PADDLE_ENFORCE_EQ(dim_b.width_ % head_number, 0); + } + if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { + int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_; + int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_; + int sub_matA_offset; + int sub_matB_offset; + int sub_matC_offset; + int sub_mat_M = dim_a.height_; + int sub_mat_N; + int sub_mat_K; + int ldc; + for (int i = 0; i < head_number; i++) { - int sub_matA_offset = - dim_a.trans_ ? i * (dim_a.width_ / head_number) * dim_a.height_ - : i * (dim_a.width_ / head_number); - int sub_matB_offset = - dim_b.trans_ ? i * (dim_b.height_ / head_number) - : i * (dim_b.height_ / head_number) * dim_b.width_; - int sub_matC_offset = i * dim_b.width_; - int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_; - int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_; - int ldc = head_number * dim_b.width_; - - this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, - dim_a.width_ / head_number, alpha, - mat_a.data() + sub_matA_offset, lda, + sub_matA_offset = dim_a.trans_ + ? i * (dim_a.width_ / head_number) * dim_a.height_ + : i * (dim_a.width_ / head_number); + if (mat_b_split_vertical) { + sub_matB_offset = dim_b.trans_ + ? i * (dim_b.width_ / head_number) * dim_b.height_ + : i * (dim_b.width_ / head_number); + sub_matC_offset = i * dim_b.width_ / head_number; + + sub_mat_N = dim_b.width_ / head_number; + sub_mat_K = dim_b.height_; + + ldc = dim_b.width_; + } else { + sub_matB_offset = + dim_b.trans_ ? i * (dim_b.height_ / head_number) + : i * (dim_b.height_ / head_number) * dim_b.width_; + sub_matC_offset = i * dim_b.width_; + + sub_mat_N = dim_b.width_; + sub_mat_K = dim_a.width_ / head_number; + + ldc = head_number * dim_b.width_; + } + + this->template GEMM(transA, transB, sub_mat_M, sub_mat_N, sub_mat_K, + alpha, mat_a.data() + sub_matA_offset, lda, mat_b.data() + sub_matB_offset, ldb, beta, mat_out->data() + sub_matC_offset, ldc); } } else { - PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ || - dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0); + PADDLE_ENFORCE_EQ((dim_a.batch_size_ == dim_b.batch_size_ || + dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0), + true); this->template BatchedGEMMWithHead( - transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, - mat_a.data(), mat_b.data(), beta, mat_out->data(), + transA, transB, dim_a.width_, dim_a.height_, dim_b.width_, + dim_b.height_, alpha, mat_a.data(), mat_b.data(), beta, + mat_out->data(), dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, - dim_a.stride_, dim_b.stride_, head_number); + dim_a.stride_, dim_b.stride_, head_number, mat_b_split_vertical); } } #endif diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index ce252dba65975643fec93e6354f7cafcec508e3e..eb43f43daf446a4b8ca872f89e4dcb18ff7323d8 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -63,11 +63,13 @@ class MatMulKernel : public framework::OpKernel { #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) int head_number = context.Attr("head_number"); - if (1 == head_number) { - blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0)); - } else { + bool split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_); + + if (head_number > 1) { blas.MatMulWithHead(x, mat_dim_a, y, mat_dim_b, scale, head_number, out, - T(0)); + T(0), split_vertical_y); + } else { + blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0)); } #else blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0)); @@ -300,19 +302,22 @@ class MatMulOp : public framework::OperatorWithKernel { math::CreateMatrixDescriptor(ColumnMatrixFromVector(dim_y), 0, context->Attrs().Get("transpose_Y")); - PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_); if (context->IsRuntime()) { PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ || mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0); } std::vector dim_out; + int64_t dim_out_y = mat_dim_y.width_; #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) int head_number = context->Attrs().Get("head_number"); - PADDLE_ENFORCE_GE(head_number, 1); + bool split_vertical_y = (mat_dim_x.width_ != mat_dim_y.height_); PADDLE_ENFORCE_LE(head_number, mat_dim_x.width_); - int64_t dim_out_y = head_number * mat_dim_y.width_; + + if (!split_vertical_y && head_number > 0) { + dim_out_y = head_number * mat_dim_y.width_; + } #else - int64_t dim_out_y = mat_dim_y.width_; + PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_); #endif if (mat_dim_x.batch_size_ != 0) { diff --git a/python/paddle/fluid/tests/unittests/test_matmul_op_with_head.py b/python/paddle/fluid/tests/unittests/test_matmul_op_with_head.py index acc8cfd8f3994095e73532aa2eed253808b9417e..3cca8af2d4973ed84f891101511432a102d1b84f 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_op_with_head.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_op_with_head.py @@ -148,11 +148,147 @@ def inject_test_multiple_head(dim_x, dim_y, trans_x, trans_y, head_number): }) +def matmul_head2(X, Y, head_number=1): + x = [] + y = [] + z = [] + sub_x_width = X.shape[-1] // head_number + sub_y_width = Y.shape[-1] // head_number + assert (sub_x_width == Y.shape[-2] + ), "Error: incompatible head number or matrix size!" + if np.ndim(X) == 2: + for i in range(0, head_number): + x.append(X[:, i * sub_x_width:i * sub_x_width + sub_x_width]) + y.append(Y[:, i * sub_y_width:i * sub_y_width + sub_y_width]) + for i in range(0, head_number): + z.append(np.matmul(x[i], y[i])) + Z = np.concatenate((z), axis=1) + + elif np.ndim(X) == 3: + for i in range(0, head_number): + x.append(X[:, :, i * sub_x_width:i * sub_x_width + sub_x_width]) + y.append(Y[:, :, i * sub_y_width:i * sub_y_width + sub_y_width]) + for i in range(0, head_number): + z.append(np.matmul(x[i], y[i])) + Z = np.concatenate((z), axis=2) + else: + assert False, "ERROR: Not supported dimension!" + return Z + + +def reference_matmul_mul_head2(X, + Y, + head_number=1, + transpose_X=False, + transpose_Y=False): + """Reference forward implementation using np.matmul.""" + # np.matmul does not support the transpose flags, so we manually + # transpose X and Y appropriately. + if transpose_X: + X = transpose_mat(X) + if transpose_Y: + Y = transpose_mat(Y) + + Out = matmul_head2(X, Y, head_number) + if not Out.shape: + # We do not support 0-dimensional Tensors (scalars). So where + # np.matmul outputs a scalar, we must convert to a Tensor of + # shape (1, ) instead. + # Everywhere else, we are compatible with np.matmul. + Out = np.array([Out], dtype="float32") + return Out + + +def generate_compatible_shapes_mul_head2(dim_X, dim_Y, transpose_X, + transpose_Y): + BATCH_SIZE = 2 + # Assume head number H is 4. We need make sure K1/H = M2 + M1 = 3 + K1 = 8 + M2 = 2 + K2 = 16 + + if dim_X >= 2: + if transpose_X: + shape_X = [K1, M1] + else: + shape_X = [M1, K1] + if dim_X == 3: + shape_X = [BATCH_SIZE] + shape_X + if dim_Y >= 2: + if transpose_Y: + shape_Y = [K2, M2] + else: + shape_Y = [M2, K2] + if dim_Y == 3: + shape_Y = [BATCH_SIZE] + shape_Y + return shape_X, shape_Y + + +# Generator for multiple head, case 2 when width of X is not same as height of Y +class GeneratorMulHead2(object): + def setUp(self): + self.op_type = "matmul" + + X = np.zeros(self.shape_X) + Y = np.zeros(self.shape_Y) + if len(self.shape_X) == 2: + X = np.arange( + 0, self.shape_X[-1] * self.shape_X[-2], + dtype=np.float32).reshape(self.shape_X) + Y = np.arange( + 0, self.shape_Y[-1] * self.shape_Y[-2], + dtype=np.float32).reshape(self.shape_Y) + else: + for i in range(0, len(self.shape_X) - 1): + X[i, :, :] = np.arange( + 0, self.shape_X[-1] * self.shape_X[-2], + dtype=np.float32).reshape(list(self.shape_X)[-2:]) + Y[i, :, :] = np.arange( + 0, self.shape_Y[-1] * self.shape_Y[-2], + dtype=np.float32).reshape(list(self.shape_Y)[-2:]) + + Out = reference_matmul_mul_head2(X, Y, 4, self.transpose_X, + self.transpose_Y) + + self.inputs = {'X': X, 'Y': Y} + self.attrs = { + 'transpose_X': self.transpose_X, + 'transpose_Y': self.transpose_Y, + 'head_number': self.head_number + } + self.outputs = {'Out': Out} + + def test_check_output(self): + self.check_output(atol=1e-3) + + +def inject_test_multiple_head2(dim_x, dim_y, trans_x, trans_y, head_number): + test_name = ( + 'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}_head2_{}'.format( + dim_x, dim_y, trans_x, trans_y, head_number)) + shape_x, shape_y = generate_compatible_shapes_mul_head2(dim_x, dim_y, + trans_x, trans_y) + globals()[test_name] = type(test_name, (GeneratorMulHead2, OpTest), { + 'shape_X': shape_x, + 'shape_Y': shape_y, + 'transpose_X': trans_x, + 'transpose_Y': trans_y, + 'head_number': head_number + }) + + #test case for multiple head for dim in (2, 3): for transose_x in (False, True): for transose_y in (False, True): inject_test_multiple_head(dim, dim, transose_x, transose_y, 4) +#test case for multiple head when X.width != Y.height +for dim in (2, 3): + for transose_x in (False, True): + for transose_y in (False, True): + inject_test_multiple_head2(dim, dim, transose_x, transose_y, 4) + if __name__ == "__main__": unittest.main()