未验证 提交 92913027 编写于 作者: K Kexin Zhao 提交者: GitHub

fix unused var error (#9908)

上级 47609ab2
......@@ -268,6 +268,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C, const int batchCount, const int strideA, const int strideB) {
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -289,7 +290,6 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
......@@ -304,6 +304,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int strideA, const int strideB) {
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -315,7 +316,6 @@ void batched_gemm<platform::CUDADeviceContext, float>(
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
......@@ -330,6 +330,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int strideA, const int strideB) {
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -341,7 +342,6 @@ void batched_gemm<platform::CUDADeviceContext, double>(
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册