Unverified commit 92913027 authored by Kexin Zhao, committed by GitHub

fix unused var error (#9908)

Parent 47609ab2
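This commit moves the `#if CUDA_VERSION >= 8000` guard from just before each cublas*gemmStridedBatched call to the top of the corresponding batched_gemm specialization (float16, float, and double). With the guard at the top, builds against CUDA toolkits older than 8.0 never declare the locals (lda, ldb, strideC, the transpose flags) that only the guarded call consumes, which is what silences the unused-variable error. Below is a minimal sketch of the pattern, assuming only the standard CUDA_VERSION toolkit macro; the function name and the printf stand-ins are hypothetical, not PaddlePaddle code:

#include <cstdio>

void guarded_gemm(int M, int N, int K) {
#if CUDA_VERSION >= 8000
  // Declared inside the guard: a CUDA < 8.0 build never sees these
  // locals, so there is nothing for -Wunused-variable to flag.
  const int lda = K;
  const int ldb = N;
  const int strideC = M * N;
  std::printf("call strided-batched gemm: lda=%d ldb=%d strideC=%d\n",
              lda, ldb, strideC);
#else
  // Before this commit the guard opened here, just around the cuBLAS
  // call, leaving lda/ldb/strideC declared but unused on this branch.
  std::printf("strided-batched gemm requires CUDA >= 8.0\n");
#endif
}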
@@ -268,6 +268,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
     float16* C, const int batchCount, const int strideA, const int strideB) {
+#if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -289,7 +290,6 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
   PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                     "cublas Hgemm requires GPU compute capability >= 53");
-#if CUDA_VERSION >= 8000
   PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
       strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
@@ -304,6 +304,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
     float* C, const int batchCount, const int strideA, const int strideB) {
+#if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -315,7 +316,6 @@ void batched_gemm<platform::CUDADeviceContext, float>(
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   const int strideC = M * N;
-#if CUDA_VERSION >= 8000
   PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
       strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
@@ -330,6 +330,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
     double* C, const int batchCount, const int strideA, const int strideB) {
+#if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -341,7 +342,6 @@ void batched_gemm<platform::CUDADeviceContext, double>(
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   const int strideC = M * N;
-#if CUDA_VERSION >= 8000
   PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
       strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));