Unverified commit 00b9e9a1, authored by chengduo and committed by GitHub

Refine cublas to support CUBLAS_TENSOR_OP_MATH (#13929)

* refine cublas
test=develop

* code refine

* refine cublas

* add GEMM_EX

* add enable_cublas_tensor_op_math doc and add cublasCall
test=develop

* fix CublasCall for cuda version
test=develop

* fix error
test=develop

* fix GEMM_EX to be compatible with gcc 4.8
test=develop

* add GEMM_EX
test=develop

* to be compatible with gcc 4.8
test=develop
Parent dd6fd4c7
...@@ -16,6 +16,9 @@
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/dynload/cublas.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+DECLARE_bool(enable_cublas_tensor_op_math);
 
 namespace paddle {
 namespace operators {
...@@ -42,11 +45,44 @@ struct CUBlas<float> {
   }
 
   template <typename... ARGS>
-  static void GEMM_BATCH(ARGS... args) {
+  static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
     PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...));
 #else
     PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
   }
+
+  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
+  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
+  template <typename... ARGS>
+  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
+                      cublasOperation_t transa, cublasOperation_t transb, int m,
+                      int n, int k, const float *alpha, const void *A,
+                      cudaDataType_t Atype, int lda, const void *B,
+                      cudaDataType_t Btype, int ldb, const float *beta, void *C,
+                      cudaDataType_t Ctype, int ldc) {
+    // Because gcc 4.8 does not expand a template parameter pack that
+    // appears in a lambda-expression, we cannot use a template parameter
+    // pack here.
+    auto cublas_call = [&]() {
+#if CUDA_VERSION >= 8000
+      VLOG(5) << "use_tensor_op_math: "
+              << (platform::TensorCoreAvailable() ? "True" : "False");
+      PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
+          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
+          lda, B, Btype, ldb, beta, C, Ctype, ldc));
+#else
+      PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
+#endif
+    };
+
+#if CUDA_VERSION >= 9000
+    // NOTES: To use Tensor Core, we should change the cublas config,
+    // but the cublas handle may be held by multiple threads.
+    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+#else
+    cublas_call();
+#endif
+  }
 };
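The gcc 4.8 note above refers to a front-end limitation: expanding a function's template parameter pack inside the body of a lambda does not compile on that version, which is why GEMM_EX spells out every argument instead of taking `ARGS... args`. A minimal self-contained illustration of the rejected pattern (hypothetical code, not part of the patch):

#include <cstdio>

// Per the comment above, gcc 4.8 fails to expand `args...` inside the lambda
// body; newer compilers accept this and print "1 2".
template <typename... ARGS>
void CallInLambda(ARGS... args) {
  auto f = [&]() { std::printf("%d %d\n", args...); };  // rejected by gcc 4.8
  f();
}

int main() { CallInLambda(1, 2); }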
...@@ -69,13 +105,18 @@ struct CUBlas<double> {
   }
 
   template <typename... ARGS>
-  static void GEMM_BATCH(ARGS... args) {
+  static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
     PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...));
 #else
     PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
   }
+
+  template <typename... ARGS>
+  static void GEMM_EX(ARGS... args) {
+    PADDLE_THROW("Currently there is no cublasDgemmEx.");
+  }
 };
 
 template <>
...@@ -96,14 +137,16 @@ struct CUBlas<platform::float16> {
                                       reinterpret_cast<__half *>(C), ldc));
   }
 
-  static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa,
-                         cublasOperation_t transb, int m, int n, int k,
-                         const float16 *alpha, const float16 *A, int lda,
-                         long long int strideA, const float16 *B,  // NOLINT
-                         int ldb, long long int strideB,            // NOLINT
-                         const float16 *beta, float16 *C, int ldc,
-                         long long int strideC,  // NOLINT
-                         int batchCount) {
+  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
+                                 cublasOperation_t transa,
+                                 cublasOperation_t transb, int m, int n, int k,
+                                 const float16 *alpha, const float16 *A,
+                                 int lda, long long int strideA,  // NOLINT
+                                 const float16 *B,                // NOLINT
+                                 int ldb, long long int strideB,  // NOLINT
+                                 const float16 *beta, float16 *C, int ldc,
+                                 long long int strideC,  // NOLINT
+                                 int batchCount) {
 #if CUDA_VERSION >= 8000
     PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
         handle, transa, transb, m, n, k,
...@@ -114,6 +157,45 @@ struct CUBlas<platform::float16> {
         ldc, strideC, batchCount));
 #else
     PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
   }
+
+  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
+  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
+  template <typename... ARGS>
+  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
+                      cublasOperation_t transa, cublasOperation_t transb, int m,
+                      int n, int k, const void *alpha, const void *A,
+                      cudaDataType_t Atype, int lda, const void *B,
+                      cudaDataType_t Btype, int ldb, const void *beta, void *C,
+                      cudaDataType_t Ctype, int ldc,
+                      cudaDataType_t computeType) {
+    auto cublas_call = [&]() {
+#if CUDA_VERSION >= 8000
+      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+#if CUDA_VERSION >= 9000
+      bool use_tensor_op_math = platform::TensorCoreAvailable();
+      if (use_tensor_op_math) {
+        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+      }
+      VLOG(5) << "use_tensor_op_math: "
+              << (use_tensor_op_math ? "True" : "False");
+#endif  // CUDA_VERSION >= 9000
+
+      PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
+          lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
+#else
+      PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
+#endif
+    };
+
+#if CUDA_VERSION >= 9000
+    // NOTES: To use Tensor Core, we should change the cublas config,
+    // but the cublas handle may be held by multiple threads.
+    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+#else
+    cublas_call();
+#endif
+  }
 };
...@@ -133,8 +215,21 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
 
-  CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
-                  B, ldb, A, lda, &beta, C, N);
+#if CUDA_VERSION >= 8000
+  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
+                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
+                       CUDA_R_32F, N);
+  } else {
+#endif  // CUDA_VERSION >= 8000
+
+    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
+                    &alpha, B, ldb, A, lda, &beta, C, N);
+
+#if CUDA_VERSION >= 8000
+  }
+#endif  // CUDA_VERSION >= 8000
 }
 
 template <>
...@@ -157,30 +252,18 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
   PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
                     "cublas fp16 gemm requires GPU compute capability >= 53");
 
-#if CUDA_VERSION >= 8000
   float h_alpha = static_cast<float>(alpha);
   float h_beta = static_cast<float>(beta);
 
-  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
-#if CUDA_VERSION >= 9000
-  if (context_.GetComputeCapability() >= 70) {
-    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
-        context_.cublas_handle(), CUBLAS_TENSOR_OP_MATH));
-    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-  } else {
-    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
-        context_.cublas_handle(), CUBLAS_DEFAULT_MATH));
-  }
-#endif  // CUDA_VERSION >= 9000
-
+#if CUDA_VERSION >= 8000
   // cublasHgemm does true FP16 computation which is slow for non-Volta
   // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
   // input/output in fp16, computation in fp32, which can also be accelerated
   // using tensor cores in volta GPUs.
-  PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
-      context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
-      CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
-      CUDA_R_32F, algo));
+  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+  CUBlas<platform::float16>::GEMM_EX(
+      &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16F, ldb, A,
+      CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
   CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
...@@ -199,8 +282,38 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
   // the cblas convention.
   cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
 
-  CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
-                  B, ldb, A, lda, &beta, C, ldc);
+#if CUDA_VERSION >= 8000
+  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
+                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
+                       CUDA_R_32F, ldc);
+  } else {
+#endif  // CUDA_VERSION >= 8000
+
+    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
+                    &alpha, B, ldb, A, lda, &beta, C, ldc);
+
+#if CUDA_VERSION >= 8000
+  }
+#endif  // CUDA_VERSION >= 8000
 }
+
+template <>
+template <>
+inline void Blas<platform::CUDADeviceContext>::GEMM(
+    bool transA, bool transB, int M, int N, int K, platform::float16 alpha,
+    const platform::float16 *A, int lda, const platform::float16 *B, int ldb,
+    platform::float16 beta, platform::float16 *C, int ldc) const {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
+                                  N, M, K, &alpha, B, ldb, A, lda, &beta, C,
+                                  ldc);
+}
 
 template <>
...@@ -238,9 +351,34 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   const int64_t strideC = M * N;
 
-  CUBlas<T>::GEMM_BATCH(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
-                        &alpha, B, ldb, strideB, A, lda, strideA, &beta, C, ldc,
-                        strideC, batchCount);
+#if CUDA_VERSION >= 9010
+  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+    auto cublas_call = [&]() {
+      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+      bool use_tensor_op_math = platform::TensorCoreAvailable();
+      if (use_tensor_op_math) {
+        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+      }
+      VLOG(5) << "use_tensor_op_math: "
+              << (use_tensor_op_math ? "True" : "False");
+
+      PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
+          context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
+          CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
+          CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
+    };
+    auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+    dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+  } else {
+#endif  // CUDA_VERSION >= 9010
+
+    CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
+                                  N, M, K, &alpha, B, ldb, strideB, A, lda,
+                                  strideA, &beta, C, ldc, strideC, batchCount);
+
+#if CUDA_VERSION >= 9010
+  }
+#endif  // CUDA_VERSION >= 9010
 }
 
 }  // namespace math
...
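For orientation, a sketch of how the new float path above is reached from an operator. The GetBlas helper, the include path, and the exact GEMM overload used here are assumptions based on the existing Blas<platform::CUDADeviceContext> interface, not something introduced by this diff:

#include "paddle/fluid/operators/math/blas.h"

// Sketch only. When FLAGS_enable_cublas_tensor_op_math is true, T is float,
// and CUDA >= 8.0, the GEMM below routes through CUBlas<float>::GEMM_EX
// (cublasSgemmEx); otherwise it falls back to the plain cublasSgemm path.
void MatMulSketch(const paddle::platform::CUDADeviceContext &ctx,
                  const float *a, const float *b, float *c, int m, int n,
                  int k) {
  namespace math = paddle::operators::math;
  auto blas = math::GetBlas<paddle::platform::CUDADeviceContext, float>(ctx);
  blas.GEMM(CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a, b, 0.0f, c);
}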
...@@ -143,6 +143,39 @@ class CudnnWorkspaceHandle {
   std::unique_ptr<std::lock_guard<std::mutex>> guard_;
 };
 
+#if CUDA_VERSION >= 9000
+class ScopedCublasMathMode {
+ public:
+  ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
+      : handle_(handle) {
+    need_reset = false;
+    PADDLE_ENFORCE(
+        platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
+        "Failed to get old cublas math mode");
+    if (old_math_mode_ != new_math_mode) {
+      PADDLE_ENFORCE(
+          platform::dynload::cublasSetMathMode(handle_, new_math_mode),
+          "Failed to set new cublas math mode");
+      need_reset = true;
+    }
+  }
+
+  ~ScopedCublasMathMode() {
+    if (need_reset) {
+      PADDLE_ENFORCE(
+          platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
+          "Failed to set old cublas math mode");
+    }
+  }
+
+ private:
+  cublasHandle_t handle_;
+  cublasMath_t old_math_mode_;
+  bool need_reset;
+};
+#endif
+
 class CUDADeviceContext : public DeviceContext {
  public:
   explicit CUDADeviceContext(CUDAPlace place);
...@@ -199,6 +232,18 @@ class CUDADeviceContext : public DeviceContext {
     callback_manager_->Wait();
   }
 
+#if CUDA_VERSION >= 9000
+  /*! \brief CublasCall may need to change cublas's config,
+   *  but the cublas handle may be held by multiple threads,
+   *  so we add a lock here. */
+  template <typename Callback>
+  void CublasCall(Callback callback, cublasMath_t new_math) {
+    std::lock_guard<std::mutex> guard(cublas_mtx_);
+    ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
+    callback();
+  }
+#endif
+
  private:
   CUDAPlace place_;
 
...@@ -220,6 +265,8 @@ class CUDADeviceContext : public DeviceContext {
   // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
   mutable std::mutex callback_mtx_;
   std::unique_ptr<StreamCallbackManager> callback_manager_;
+
+  mutable std::mutex cublas_mtx_;
 };
 
 template <>
...
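A minimal sketch of how the new CublasCall and ScopedCublasMathMode pair is meant to be used, mirroring the GEMM_EX call sites shown earlier (the surrounding function is hypothetical):

// Sketch only: every cuBLAS routine issued inside the callback runs with
// CUBLAS_TENSOR_OP_MATH; the previous math mode is restored when the scoped
// guard is destroyed, and cublas_mtx_ serializes concurrent mode changes.
void RunWithTensorOpMath(paddle::platform::CUDADeviceContext *dev_ctx) {
  dev_ctx->CublasCall(
      [&]() {
        // e.g. platform::dynload::cublasGemmEx(dev_ctx->cublas_handle(), ...);
      },
      CUBLAS_TENSOR_OP_MATH);
}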
...@@ -61,9 +61,6 @@ extern void *cublas_dso_handle;
   extern DynLoad__##__name __name
 #endif
 
-#define DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) \
-  DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)
-
 #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
   __macro(cublasSaxpy_v2);                \
   __macro(cublasDaxpy_v2);                \
...@@ -93,22 +90,23 @@ CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
 
 // APIs available after CUDA 8.0
 #if CUDA_VERSION >= 8000
-#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \
-  __macro(cublasGemmEx);                     \
-  __macro(cublasSgemmStridedBatched);        \
-  __macro(cublasDgemmStridedBatched);        \
-  __macro(cublasCgemmStridedBatched);        \
-  __macro(cublasZgemmStridedBatched);        \
-  __macro(cublasHgemmStridedBatched);
-
-CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmEx);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmStridedBatched);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmStridedBatched);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmStridedBatched);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmStridedBatched);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasHgemmStridedBatched);
 #endif
 
 // APIs available after CUDA 9.0
 #if CUDA_VERSION >= 9000
-#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) __macro(cublasSetMathMode);
-
-CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSetMathMode);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGetMathMode);
+#endif
+
+#if CUDA_VERSION >= 9010
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmBatchedEx);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmStridedBatchedEx);
 #endif
 
 #undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP
...
...@@ -26,6 +26,16 @@ DEFINE_double(fraction_of_gpu_memory_to_use, 0.92,
               "additional trunks of the same size will be requested from gpu "
               "until the gpu has no memory left for another trunk.");
 
+DEFINE_bool(
+    enable_cublas_tensor_op_math, false,
+    "The enable_cublas_tensor_op_math flag indicates whether to use Tensor "
+    "Cores, which may lose precision. Currently, two CUDA libraries can use "
+    "Tensor Cores: cuBLAS and cuDNN. cuBLAS uses Tensor Cores to speed up "
+    "GEMM computations (the matrices must be either half precision or single "
+    "precision); cuDNN uses Tensor Cores to speed up both convolutions (the "
+    "input and output must be half precision) and recurrent neural networks "
+    "(RNNs).");
+
 namespace paddle {
 namespace platform {
...@@ -64,6 +74,16 @@ int GetCUDADriverVersion(int id) {
   return driver_version;
 }
 
+bool TensorCoreAvailable() {
+#if CUDA_VERSION >= 9000
+  int device = GetCurrentDeviceId();
+  int driver_version = GetCUDAComputeCapability(device);
+  return driver_version >= 70;
+#else
+  return false;
+#endif
+}
+
 int GetCUDAMultiProcessors(int id) {
   PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
   int count;
...
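TensorCoreAvailable() reports whether the current device's compute capability is at least 7.0 (Volta); the GEMM_EX call sites above use it to decide which cuBLAS algorithm to request. A condensed sketch of that gating pattern:

// Sketch only, condensed from the call sites above: request the Tensor Core
// algorithm only when the device supports it, otherwise keep the default.
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
if (paddle::platform::TensorCoreAvailable()) {
  algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
#endif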
...@@ -35,6 +35,9 @@ int GetCUDARuntimeVersion(int id);
 //! Get the driver version of the ith GPU
 int GetCUDADriverVersion(int id);
 
+//! Whether the current device supports TensorCore
+bool TensorCoreAvailable();
+
 //! Get the MultiProcessors of the ith GPU.
 int GetCUDAMultiProcessors(int i);
...
...@@ -133,7 +133,8 @@ def __bootstrap__():
     if core.is_compiled_with_cuda():
         read_env_flags += [
             'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
-            'conv_workspace_size_limit', 'cudnn_exhaustive_search'
+            'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
+            'cudnn_exhaustive_search'
         ]
         core.init_gflags([sys.argv[0]] +
                          ["--tryfromenv=" + ",".join(read_env_flags)])
...
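Since enable_cublas_tensor_op_math is added to read_env_flags and read through --tryfromenv, it can presumably be toggled from the environment before paddle.fluid is imported; a minimal sketch (the surrounding script is hypothetical):

# Sketch only: set the FLAGS_ environment variable before importing fluid so
# that __bootstrap__ picks it up via --tryfromenv.
import os

os.environ['FLAGS_enable_cublas_tensor_op_math'] = 'true'

import paddle.fluid as fluid  # the flag must be set before this import

# Float GEMMs on devices with compute capability >= 7.0 may now run through
# cublasGemmEx / cublasGemmStridedBatchedEx with CUBLAS_TENSOR_OP_MATH.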