未验证 提交 3ea7d577 编写于 作者: Z Zhang Zheng 提交者: GitHub

Fix compilation error by using thrust (#54364)

* Fix compilation error by using thrust

* fix
上级 33e7a46d
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#pragma once #pragma once
#if defined(__NVCC__)
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#endif
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
...@@ -282,6 +284,7 @@ struct CUBlas<phi::dtype::float16> { ...@@ -282,6 +284,7 @@ struct CUBlas<phi::dtype::float16> {
ldc)); ldc));
} }
#if defined(__NVCC__)
static void GEMM_BATCH(phi::GPUContext *dev_ctx, static void GEMM_BATCH(phi::GPUContext *dev_ctx,
cublasOperation_t transa, cublasOperation_t transa,
cublasOperation_t transb, cublasOperation_t transb,
...@@ -342,6 +345,7 @@ struct CUBlas<phi::dtype::float16> { ...@@ -342,6 +345,7 @@ struct CUBlas<phi::dtype::float16> {
"cublasGemmBatchedEx is not supported on cuda <= 7.5")); "cublasGemmBatchedEx is not supported on cuda <= 7.5"));
#endif #endif
} }
#endif
static void GEMM_STRIDED_BATCH(cublasHandle_t handle, static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transa,
...@@ -1754,6 +1758,7 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, ...@@ -1754,6 +1758,7 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
} }
} }
#if defined(__NVCC__)
template <> template <>
template <> template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
...@@ -1973,6 +1978,7 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, ...@@ -1973,6 +1978,7 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 11000 #endif // CUDA_VERSION >= 11000
} }
#endif
template <> template <>
template <typename T> template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册