提交 ccb322d6 编写于 作者: Z Zeng Jinle 提交者: sneaxiy

merge develop

......@@ -62,27 +62,19 @@ struct CUBlas<float> {
cudaDataType_t Atype, int lda, const void *B,
cudaDataType_t Btype, int ldb, const float *beta, void *C,
cudaDataType_t Ctype, int ldc) {
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
auto cublas_call = [&]() {
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
#if CUDA_VERSION >= 8000
VLOG(5) << "use_tensor_op_math: "
<< (platform::TensorCoreAvailable() ? "True" : "False");
<< (dev_ctx->tensor_core_available() ? "True" : "False");
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
lda, B, Btype, ldb, beta, C, Ctype, ldc));
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc));
});
#else
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
#else
cublas_call();
#endif
}
};
......@@ -170,11 +162,10 @@ struct CUBlas<platform::float16> {
cudaDataType_t Btype, int ldb, const void *beta, void *C,
cudaDataType_t Ctype, int ldc,
cudaDataType_t computeType) {
auto cublas_call = [&]() {
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
bool use_tensor_op_math = platform::TensorCoreAvailable();
bool use_tensor_op_math = dev_ctx->tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
......@@ -182,20 +173,13 @@ struct CUBlas<platform::float16> {
<< (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc, computeType, algo));
});
#else
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
#else
cublas_call();
#endif
}
};
......@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CUDA_R_32F, N);
} else {
#endif // CUDA_VERSION >= 8000
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
&alpha, B, ldb, A, lda, &beta, C, N);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, N);
});
#if CUDA_VERSION >= 8000
}
......@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &h_alpha, h_B, ldb, h_A, lda,
&h_beta, h_C, N);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
&h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C,
N);
});
#endif // CUDA_VERSION >= 8000
}
......@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
} else {
#endif // CUDA_VERSION >= 8000
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
&alpha, B, ldb, A, lda, &beta, C, ldc);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, ldc);
});
#if CUDA_VERSION >= 8000
}
......@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &alpha, B, ldb, A, lda, &beta, C,
ldc);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, A, lda, &beta, C, ldc);
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
T *y) const {
CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
});
}
template <>
......@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T beta, T *C) const {
cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1,
&beta, C, 1);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
});
}
template <>
......@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010
if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
auto cublas_call = [&]() {
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = platform::TensorCoreAvailable();
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
};
auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
strideC, batchCount, CUDA_R_32F, algo));
});
} else {
#endif // CUDA_VERSION >= 9010
CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &alpha, B, ldb, strideB, A, lda,
strideA, &beta, C, ldc, strideC, batchCount);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, strideB, A, lda, strideA, &beta, C,
ldc, strideC, batchCount);
});
#if CUDA_VERSION >= 9010
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 };
#endif
namespace paddle {
namespace platform {
class CublasHandleHolder {
public:
CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
PADDLE_ENFORCE(dynload::cublasCreate(&handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream));
#if CUDA_VERSION >= 9000
if (math_type == CUBLAS_TENSOR_OP_MATH) {
PADDLE_ENFORCE(
dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
}
~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); }
template <typename Callback>
inline void Call(Callback &&callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
private:
DISABLE_COPY_AND_ASSIGN(CublasHandleHolder);
cublasHandle_t handle_;
mutable std::mutex mtx_;
};
} // namespace platform
} // namespace paddle
......@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(
new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
#endif
}
if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place));
}
......@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
WaitStreamCallback();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
cublas_handle_.reset();
cublas_tensor_core_handle_.reset();
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
......@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return cublas_handle_;
bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
......@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
std::unique_ptr<std::lock_guard<std::mutex>> guard_;
};
#if CUDA_VERSION >= 9000
class ScopedCublasMathMode {
public:
ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
: handle_(handle) {
need_reset = false;
PADDLE_ENFORCE(
platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
"Failed to get old cublas math mode");
if (old_math_mode_ != new_math_mode) {
PADDLE_ENFORCE(
platform::dynload::cublasSetMathMode(handle_, new_math_mode),
"Failed to set old cublas math mode");
need_reset = true;
}
}
~ScopedCublasMathMode() {
if (need_reset) {
PADDLE_ENFORCE(
platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
"Failed to set old cublas math mode");
}
}
private:
cublasHandle_t handle_;
cublasMath_t old_math_mode_;
bool need_reset;
};
#endif
class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(CUDAPlace place);
......@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */
Eigen::GpuDevice* eigen_device() const;
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;
/*! \brief Call cublas function safely. */
template <typename Callback>
inline void CublasCall(Callback&& callback) const {
cublas_handle_->Call(std::forward<Callback>(callback));
}
/*! \brief Check whether tensor core is supported */
bool tensor_core_available() const;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template <typename Callback>
inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
if (cublas_tensor_core_handle_) {
cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
} else {
cublas_handle_->Call(std::forward<Callback>(callback));
}
}
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
......@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) {
std::lock_guard<std::mutex> guard(mtx_);
callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
}
......@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
void WaitStreamCallback() const { callback_manager_->Wait(); }
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template <typename Callback>
void CublasCall(Callback callback, cublasMath_t new_math) {
std::lock_guard<std::mutex> guard(cublas_mtx_);
ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
callback();
}
#endif
private:
CUDAPlace place_;
......@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_;
cublasHandle_t cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
int compute_capability_;
int runtime_version_;
......@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
int multi_process_;
int max_threads_per_mp_;
mutable std::mutex mtx_;
// StreamCallbackManager is thread-safe
std::unique_ptr<StreamCallbackManager> callback_manager_;
mutable std::mutex cublas_mtx_;
DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
};
template <>
......
......@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
ASSERT_NE(nullptr, device_context->stream());
delete device_context;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册