提交 d25395fc 编写于 作者: S sneaxiy

remove tensor core lock

test=develop
上级 9c2cbfb8
......@@ -62,27 +62,17 @@ struct CUBlas<float> {
cudaDataType_t Atype, int lda, const void *B,
cudaDataType_t Btype, int ldb, const float *beta, void *C,
cudaDataType_t Ctype, int ldc) {
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
auto cublas_call = [&]() {
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
#if CUDA_VERSION >= 8000
VLOG(5) << "use_tensor_op_math: "
<< (platform::TensorCoreAvailable() ? "True" : "False");
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
lda, B, Btype, ldb, beta, C, Ctype, ldc));
VLOG(5) << "use_tensor_op_math: "
<< (dev_ctx->tensor_core_available() ? "True" : "False");
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc));
#else
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
#else
cublas_call();
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif
}
};
......@@ -170,32 +160,23 @@ struct CUBlas<platform::float16> {
cudaDataType_t Btype, int ldb, const void *beta, void *C,
cudaDataType_t Ctype, int ldc,
cudaDataType_t computeType) {
auto cublas_call = [&]() {
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
bool use_tensor_op_math = platform::TensorCoreAvailable();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
bool use_tensor_op_math = dev_ctx->tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
#else
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType,
algo));
#else
cublas_call();
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif
}
};
......@@ -353,22 +334,18 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010
if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
auto cublas_call = [&]() {
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = platform::TensorCoreAvailable();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
};
auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
context_.possible_cublas_tensor_core_handle(), cuTransB, cuTransA, N, M,
K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA,
&beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
} else {
#endif // CUDA_VERSION >= 9010
......
......@@ -247,6 +247,18 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(new cublasHandle_t());
PADDLE_ENFORCE(dynload::cublasCreate(cublas_tensor_core_handle_.get()));
PADDLE_ENFORCE(
dynload::cublasSetStream(*cublas_tensor_core_handle_, stream_));
PADDLE_ENFORCE(dynload::cublasSetMathMode(*cublas_tensor_core_handle_,
CUBLAS_TENSOR_OP_MATH));
#endif
}
if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place));
}
......@@ -307,6 +319,10 @@ CUDADeviceContext::~CUDADeviceContext() {
Wait();
WaitStreamCallback();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
if (cublas_tensor_core_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(*cublas_tensor_core_handle_));
cublas_tensor_core_handle_.reset();
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
......@@ -339,6 +355,15 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
return cublas_handle_;
}
cublasHandle_t CUDADeviceContext::possible_cublas_tensor_core_handle() const {
return cublas_tensor_core_handle_ ? *cublas_tensor_core_handle_
: cublas_handle_;
}
bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->cudnn_handle();
}
......
......@@ -209,39 +209,6 @@ class CudnnWorkspaceHandle {
std::unique_ptr<std::lock_guard<std::mutex>> guard_;
};
#if CUDA_VERSION >= 9000
class ScopedCublasMathMode {
public:
ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
: handle_(handle) {
need_reset = false;
PADDLE_ENFORCE(
platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
"Failed to get old cublas math mode");
if (old_math_mode_ != new_math_mode) {
PADDLE_ENFORCE(
platform::dynload::cublasSetMathMode(handle_, new_math_mode),
"Failed to set old cublas math mode");
need_reset = true;
}
}
~ScopedCublasMathMode() {
if (need_reset) {
PADDLE_ENFORCE(
platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
"Failed to set old cublas math mode");
}
}
private:
cublasHandle_t handle_;
cublasMath_t old_math_mode_;
bool need_reset;
};
#endif
class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(CUDAPlace place);
......@@ -265,6 +232,13 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;
/*! \brief Check whether tensor core is supported */
bool tensor_core_available() const;
/*! \brief Return cublas handle supporting Tensor Core. If Tensor Core is
* not supported, return the same handle as cublas_handle(). */
cublasHandle_t possible_cublas_tensor_core_handle() const;
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
......@@ -294,18 +268,6 @@ class CUDADeviceContext : public DeviceContext {
void WaitStreamCallback() const { callback_manager_->Wait(); }
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template <typename Callback>
void CublasCall(Callback callback, cublasMath_t new_math) {
std::lock_guard<std::mutex> guard(cublas_mtx_);
ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
callback();
}
#endif
private:
CUDAPlace place_;
......@@ -314,6 +276,7 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_;
cublasHandle_t cublas_handle_;
std::unique_ptr<cublasHandle_t> cublas_tensor_core_handle_;
int compute_capability_;
int runtime_version_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册