提交 6f06e6cd 编写于 作者: S sneaxiy

Merge remote origin

test=develop
...@@ -62,27 +62,17 @@ struct CUBlas<float> { ...@@ -62,27 +62,17 @@ struct CUBlas<float> {
cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Atype, int lda, const void *B,
cudaDataType_t Btype, int ldb, const float *beta, void *C, cudaDataType_t Btype, int ldb, const float *beta, void *C,
cudaDataType_t Ctype, int ldc) { cudaDataType_t Ctype, int ldc) {
// Because the gcc 4.8 doesn't expand template parameter pack that // Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack // appears in a lambda-expression, I can not use template parameter pack
// here. // here.
auto cublas_call = [&]() {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
VLOG(5) << "use_tensor_op_math: " VLOG(5) << "use_tensor_op_math: "
<< (platform::TensorCoreAvailable() ? "True" : "False"); << (dev_ctx->tensor_core_available() ? "True" : "False");
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
lda, B, Btype, ldb, beta, C, Ctype, ldc)); alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc));
#else #else
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
#else
cublas_call();
#endif #endif
} }
}; };
...@@ -170,32 +160,23 @@ struct CUBlas<platform::float16> { ...@@ -170,32 +160,23 @@ struct CUBlas<platform::float16> {
cudaDataType_t Btype, int ldb, const void *beta, void *C, cudaDataType_t Btype, int ldb, const void *beta, void *C,
cudaDataType_t Ctype, int ldc, cudaDataType_t Ctype, int ldc,
cudaDataType_t computeType) { cudaDataType_t computeType) {
auto cublas_call = [&]() {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
bool use_tensor_op_math = platform::TensorCoreAvailable(); bool use_tensor_op_math = dev_ctx->tensor_core_available();
if (use_tensor_op_math) { if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP; algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
} }
VLOG(5) << "use_tensor_op_math: " VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False"); << (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
PADDLE_ENFORCE(platform::dynload::cublasGemmEx( PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)); alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType,
#else algo));
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
#else #else
cublas_call(); PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif #endif
} }
}; };
...@@ -353,22 +334,18 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM( ...@@ -353,22 +334,18 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010 #if CUDA_VERSION >= 9010
if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) { if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
auto cublas_call = [&]() { cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; bool use_tensor_op_math = context_.tensor_core_available();
bool use_tensor_op_math = platform::TensorCoreAvailable(); if (use_tensor_op_math) {
if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
algo = CUBLAS_GEMM_DFALT_TENSOR_OP; }
} VLOG(5) << "use_tensor_op_math: "
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
<< (use_tensor_op_math ? "True" : "False");
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( context_.possible_cublas_tensor_core_handle(), cuTransB, cuTransA, N, M,
context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA,
CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, &beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
};
auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
} else { } else {
#endif // CUDA_VERSION >= 9010 #endif // CUDA_VERSION >= 9010
......
...@@ -247,6 +247,18 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) ...@@ -247,6 +247,18 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(new cublasHandle_t());
PADDLE_ENFORCE(dynload::cublasCreate(cublas_tensor_core_handle_.get()));
PADDLE_ENFORCE(
dynload::cublasSetStream(*cublas_tensor_core_handle_, stream_));
PADDLE_ENFORCE(dynload::cublasSetMathMode(*cublas_tensor_core_handle_,
CUBLAS_TENSOR_OP_MATH));
#endif
}
if (dynload::HasCUDNN()) { if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place)); cudnn_holder_.reset(new CudnnHolder(&stream_, place));
} }
...@@ -307,6 +319,10 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -307,6 +319,10 @@ CUDADeviceContext::~CUDADeviceContext() {
Wait(); Wait();
WaitStreamCallback(); WaitStreamCallback();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
if (cublas_tensor_core_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(*cublas_tensor_core_handle_));
cublas_tensor_core_handle_.reset();
}
eigen_stream_.reset(); eigen_stream_.reset();
eigen_device_.reset(); eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE(cudaStreamDestroy(stream_));
...@@ -339,6 +355,15 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const { ...@@ -339,6 +355,15 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
return cublas_handle_; return cublas_handle_;
} }
cublasHandle_t CUDADeviceContext::possible_cublas_tensor_core_handle() const {
return cublas_tensor_core_handle_ ? *cublas_tensor_core_handle_
: cublas_handle_;
}
bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->cudnn_handle(); return cudnn_holder_->cudnn_handle();
} }
......
...@@ -209,39 +209,6 @@ class CudnnWorkspaceHandle { ...@@ -209,39 +209,6 @@ class CudnnWorkspaceHandle {
std::unique_ptr<std::lock_guard<std::mutex>> guard_; std::unique_ptr<std::lock_guard<std::mutex>> guard_;
}; };
#if CUDA_VERSION >= 9000
class ScopedCublasMathMode {
public:
ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
: handle_(handle) {
need_reset = false;
PADDLE_ENFORCE(
platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
"Failed to get old cublas math mode");
if (old_math_mode_ != new_math_mode) {
PADDLE_ENFORCE(
platform::dynload::cublasSetMathMode(handle_, new_math_mode),
"Failed to set old cublas math mode");
need_reset = true;
}
}
~ScopedCublasMathMode() {
if (need_reset) {
PADDLE_ENFORCE(
platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
"Failed to set old cublas math mode");
}
}
private:
cublasHandle_t handle_;
cublasMath_t old_math_mode_;
bool need_reset;
};
#endif
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
explicit CUDADeviceContext(CUDAPlace place); explicit CUDADeviceContext(CUDAPlace place);
...@@ -265,6 +232,13 @@ class CUDADeviceContext : public DeviceContext { ...@@ -265,6 +232,13 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cublas handle in the device context. */ /*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const; cublasHandle_t cublas_handle() const;
/*! \brief Check whether tensor core is supported */
bool tensor_core_available() const;
/*! \brief Return cublas handle supporting Tensor Core. If Tensor Core is
* not supported, return the same handle as cublas_handle(). */
cublasHandle_t possible_cublas_tensor_core_handle() const;
/*! \brief Return cudnn handle in the device context. */ /*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const; cudnnHandle_t cudnn_handle() const;
...@@ -294,18 +268,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -294,18 +268,6 @@ class CUDADeviceContext : public DeviceContext {
void WaitStreamCallback() const { callback_manager_->Wait(); } void WaitStreamCallback() const { callback_manager_->Wait(); }
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template <typename Callback>
void CublasCall(Callback callback, cublasMath_t new_math) {
std::lock_guard<std::mutex> guard(cublas_mtx_);
ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
callback();
}
#endif
private: private:
CUDAPlace place_; CUDAPlace place_;
...@@ -314,6 +276,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -314,6 +276,7 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<CudnnHolder> cudnn_holder_; std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_; cudaStream_t stream_;
cublasHandle_t cublas_handle_; cublasHandle_t cublas_handle_;
std::unique_ptr<cublasHandle_t> cublas_tensor_core_handle_;
int compute_capability_; int compute_capability_;
int runtime_version_; int runtime_version_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册