未验证 提交 3ea7d577 编写于 作者: Z Zhang Zheng 提交者: GitHub

Fix compilation error by using thrust (#54364)

* Fix compilation error by using thrust

* fix
上级 33e7a46d
......@@ -14,7 +14,9 @@
#pragma once
#if defined(__NVCC__)
#include <thrust/device_vector.h>
#endif
#include "gflags/gflags.h"
#include "glog/logging.h"
......@@ -282,6 +284,7 @@ struct CUBlas<phi::dtype::float16> {
ldc));
}
#if defined(__NVCC__)
static void GEMM_BATCH(phi::GPUContext *dev_ctx,
cublasOperation_t transa,
cublasOperation_t transb,
......@@ -342,6 +345,7 @@ struct CUBlas<phi::dtype::float16> {
"cublasGemmBatchedEx is not supported on cuda <= 7.5"));
#endif
}
#endif
static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
cublasOperation_t transa,
......@@ -1754,6 +1758,7 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
}
}
#if defined(__NVCC__)
template <>
template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
......@@ -1973,6 +1978,7 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 11000
}
#endif
template <>
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册