diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index 5e4bfd9be248d681e95397895db055104234c02f..84eea97da9f510cdae08992db24bc374b7017cb0 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -14,7 +14,9 @@ #pragma once +#if defined(__NVCC__) #include +#endif #include "gflags/gflags.h" #include "glog/logging.h" @@ -282,6 +284,7 @@ struct CUBlas { ldc)); } +#if defined(__NVCC__) static void GEMM_BATCH(phi::GPUContext *dev_ctx, cublasOperation_t transa, cublasOperation_t transb, @@ -342,6 +345,7 @@ struct CUBlas { "cublasGemmBatchedEx is not supported on cuda <= 7.5")); #endif } +#endif static void GEMM_STRIDED_BATCH(cublasHandle_t handle, cublasOperation_t transa, @@ -1754,6 +1758,7 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, } } +#if defined(__NVCC__) template <> template <> inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, @@ -1973,6 +1978,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 11000 } +#endif template <> template