Unverified Commit 4478389c authored by: R ronnywang Committed by: GitHub

[ROCM] fix bmm_kernel (#45530)

Parent 4a25b60d
@@ -1128,6 +1128,108 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
});
}
// NOTE(wangran16): unknown bug: parameters get dislocated when calling
// GEMM_STRIDED_BATCH<float> and GEMM_STRIDED_BATCH<double>, so the float and
// double cases are specialized below and call rocBLAS directly.
template <>
template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int M,
int N,
int K,
float alpha,
const float *A,
const float *B,
float beta,
float *C,
int batchCount,
int64_t strideA,
int64_t strideB) const {
  // Note: rocBLAS, like cuBLAS, follows Fortran (column-major) order, so the
  // operand order below differs from the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
const int64_t strideC = M * N;
context_.CublasCall([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::rocblas_sgemm_strided_batched(handle,
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
strideB,
A,
lda,
strideA,
&beta,
C,
ldc,
strideC,
batchCount));
});
}
template <>
template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int M,
int N,
int K,
double alpha,
const double *A,
const double *B,
double beta,
double *C,
int batchCount,
int64_t strideA,
int64_t strideB) const {
  // Note: rocBLAS, like cuBLAS, follows Fortran (column-major) order, so the
  // operand order below differs from the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
const int64_t strideC = M * N;
context_.CublasCall([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::rocblas_dgemm_strided_batched(handle,
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
strideB,
A,
lda,
strideA,
&beta,
C,
ldc,
strideC,
batchCount));
});
}
template <>
template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
......