未验证 提交 8db3ff1f 编写于 作者: L lishicheng1996 提交者: GitHub

fix a bug caused by hipcc lambda value capture (#55612)

上级 017a6164
......@@ -1173,6 +1173,56 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
});
}
// Strided-batched half-precision GEMM specialization for the ROCm backend.
//
// Issues a single rocblas_hgemm_strided_batched call covering `batchCount`
// matrix products. Consecutive A and B matrices are `strideA` and `strideB`
// elements apart; consecutive C matrices are M * N elements apart.
//
// transA / transB select whether A / B are passed to rocBLAS as transposed.
// alpha and beta are forwarded to rocBLAS by pointer as the GEMM scalars.
template <>
template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               float16 alpha,
                                               const float16 *A,
                                               const float16 *B,
                                               float16 beta,
                                               float16 *C,
                                               int batchCount,
                                               int64_t strideA,
                                               int64_t strideB) const {
  // Note that rocblas (like cublas) follows fortran order, so the order is
  // different from the cblas convention: B is passed before A and N before M
  // in the call below.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  // Widen before multiplying: M * N evaluated in `int` can overflow for large
  // outputs before the result is stored in the 64-bit stride.
  const int64_t strideC = static_cast<int64_t>(M) * N;
  // NOTE(review): alpha and beta are captured by reference and their addresses
  // are handed to rocBLAS; this assumes CublasCall runs the lambda before this
  // function returns -- confirm against CublasCall's implementation.
  context_.CublasCall([&](rocblas_handle handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::rocblas_hgemm_strided_batched(
        handle,
        cuTransB,
        cuTransA,
        N,
        M,
        K,
        reinterpret_cast<const rocblas_half *>(&alpha),
        reinterpret_cast<const rocblas_half *>(B),
        ldb,
        strideB,
        reinterpret_cast<const rocblas_half *>(A),
        lda,
        strideA,
        reinterpret_cast<const rocblas_half *>(&beta),
        reinterpret_cast<rocblas_half *>(C),
        ldc,
        strideC,
        batchCount));
  });
}
// note(wangran16): unknown bug. Parameters are misaligned when calling
// GEMM_STRIDED_BATCH<float> and GEMM_STRIDED_BATCH<double>.
template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册