diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.hip.h b/paddle/phi/kernels/funcs/blas/blas_impl.hip.h index e322fba39a481b50e0a49a4390d0c993de77d450..b35882df1e5ed3ac10536c1b8c3e0addcfe8ff99 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.hip.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.hip.h @@ -1128,6 +1128,108 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, }); } +// note(wangran16): unknown bug. parameters dislocation when calling +// GEMM_STRIDED_BATCH and GEMM_STRIDED_BATCH +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + float alpha, + const float *A, + const float *B, + float beta, + float *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + const int64_t strideC = M * N; + context_.CublasCall([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_sgemm_strided_batched(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + strideB, + A, + lda, + strideA, + &beta, + C, + ldc, + strideC, + batchCount)); + }); +} + +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + double alpha, + const double *A, + const double *B, + double beta, + double *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + const int64_t strideC = M * N; + context_.CublasCall([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_dgemm_strided_batched(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + strideB, + A, + lda, + strideA, + &beta, + C, + ldc, + strideC, + batchCount)); + }); +} + template <> template <> inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA,