提交 d0a8a1e9 编写于 作者: S sneaxiy

remove_op_handle_lock

test=develop
上级 6f06e6cd
...@@ -68,9 +68,11 @@ struct CUBlas<float> { ...@@ -68,9 +68,11 @@ struct CUBlas<float> {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
VLOG(5) << "use_tensor_op_math: " VLOG(5) << "use_tensor_op_math: "
<< (dev_ctx->tensor_core_available() ? "True" : "False"); << (dev_ctx->tensor_core_available() ? "True" : "False");
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k, handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)); beta, C, Ctype, ldc));
});
#else #else
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif #endif
...@@ -171,10 +173,11 @@ struct CUBlas<platform::float16> { ...@@ -171,10 +173,11 @@ struct CUBlas<platform::float16> {
<< (use_tensor_op_math ? "True" : "False"); << (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmEx( PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k, handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, beta, C, Ctype, ldc, computeType, algo));
algo)); });
#else #else
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif #endif
...@@ -204,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -204,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CUDA_R_32F, N); CUDA_R_32F, N);
} else { } else {
#endif // CUDA_VERSION >= 8000 #endif // CUDA_VERSION >= 8000
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
&alpha, B, ldb, A, lda, &beta, C, N); lda, &beta, C, N);
});
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
} }
...@@ -247,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM( ...@@ -247,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F); CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
#else #else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &h_alpha, h_B, ldb, h_A, lda, context_.CublasCall([&](cublasHandle_t handle) {
&h_beta, h_C, N); CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
&h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C,
N);
});
#endif // CUDA_VERSION >= 8000 #endif // CUDA_VERSION >= 8000
} }
...@@ -273,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M, ...@@ -273,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
} else { } else {
#endif // CUDA_VERSION >= 8000 #endif // CUDA_VERSION >= 8000
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, context_.CublasCall([&](cublasHandle_t handle) {
&alpha, B, ldb, A, lda, &beta, C, ldc); CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, ldc);
});
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
} }
...@@ -292,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM( ...@@ -292,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, context_.CublasCall([&](cublasHandle_t handle) {
N, M, K, &alpha, B, ldb, A, lda, &beta, C, CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
ldc); B, ldb, A, lda, &beta, C, ldc);
});
} }
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x, void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
T *y) const { T *y) const {
CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1); context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
});
} }
template <> template <>
...@@ -311,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N, ...@@ -311,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T beta, T *C) const { T beta, T *C) const {
cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1, context_.CublasCall([&](cublasHandle_t handle) {
&beta, C, 1); CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
});
} }
template <> template <>
...@@ -342,16 +355,20 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM( ...@@ -342,16 +355,20 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
VLOG(5) << "use_tensor_op_math: " VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False"); << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
context_.possible_cublas_tensor_core_handle(), cuTransB, cuTransA, N, M, handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
&beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); strideC, batchCount, CUDA_R_32F, algo));
});
} else { } else {
#endif // CUDA_VERSION >= 9010 #endif // CUDA_VERSION >= 9010
CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA, context_.CublasCall([&](cublasHandle_t handle) {
N, M, K, &alpha, B, ldb, strideB, A, lda, CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
strideA, &beta, C, ldc, strideC, batchCount); B, ldb, strideB, A, lda, strideA, &beta, C,
ldc, strideC, batchCount);
});
#if CUDA_VERSION >= 9010 #if CUDA_VERSION >= 9010
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 };
#endif
namespace paddle {
namespace platform {
class CublasHandleHolder {
public:
CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
PADDLE_ENFORCE(dynload::cublasCreate(&handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream));
#if CUDA_VERSION >= 9000
if (math_type == CUBLAS_TENSOR_OP_MATH) {
PADDLE_ENFORCE(
dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
}
~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); }
template <typename Callback>
inline void Call(Callback &&callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
private:
DISABLE_COPY_AND_ASSIGN(CublasHandleHolder);
cublasHandle_t handle_;
mutable std::mutex mtx_;
};
} // namespace platform
} // namespace paddle
...@@ -245,17 +245,12 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) ...@@ -245,17 +245,12 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place); eigen_stream_->Reinitialize(&stream_, place);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
if (TensorCoreAvailable()) { if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(new cublasHandle_t()); cublas_tensor_core_handle_.reset(
PADDLE_ENFORCE(dynload::cublasCreate(cublas_tensor_core_handle_.get())); new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
PADDLE_ENFORCE(
dynload::cublasSetStream(*cublas_tensor_core_handle_, stream_));
PADDLE_ENFORCE(dynload::cublasSetMathMode(*cublas_tensor_core_handle_,
CUBLAS_TENSOR_OP_MATH));
#endif #endif
} }
...@@ -318,11 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -318,11 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device); SetDeviceId(place_.device);
Wait(); Wait();
WaitStreamCallback(); WaitStreamCallback();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); cublas_handle_.reset();
if (cublas_tensor_core_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(*cublas_tensor_core_handle_));
cublas_tensor_core_handle_.reset(); cublas_tensor_core_handle_.reset();
}
eigen_stream_.reset(); eigen_stream_.reset();
eigen_device_.reset(); eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE(cudaStreamDestroy(stream_));
...@@ -351,15 +343,6 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { ...@@ -351,15 +343,6 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get(); return eigen_device_.get();
} }
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return cublas_handle_;
}
cublasHandle_t CUDADeviceContext::possible_cublas_tensor_core_handle() const {
return cublas_tensor_core_handle_ ? *cublas_tensor_core_handle_
: cublas_handle_;
}
bool CUDADeviceContext::tensor_core_available() const { bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr; return cublas_tensor_core_handle_ != nullptr;
} }
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h" #include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
...@@ -229,15 +230,25 @@ class CUDADeviceContext : public DeviceContext { ...@@ -229,15 +230,25 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */ /*! \brief Return eigen device in the device context. */
Eigen::GpuDevice* eigen_device() const; Eigen::GpuDevice* eigen_device() const;
/*! \brief Return cublas handle in the device context. */ /*! \brief Call cublas function safely. */
cublasHandle_t cublas_handle() const; template <typename Callback>
inline void CublasCall(Callback&& callback) const {
cublas_handle_->Call(std::forward<Callback>(callback));
}
/*! \brief Check whether tensor core is supported */ /*! \brief Check whether tensor core is supported */
bool tensor_core_available() const; bool tensor_core_available() const;
/*! \brief Return cublas handle supporting Tensor Core. If Tensor Core is /*! \brief Call cublas function with Tensor Core safely. If
* not supported, return the same handle as cublas_handle(). */ Tensor Core is not available, use DEFAULT_MATH instead. */
cublasHandle_t possible_cublas_tensor_core_handle() const; template <typename Callback>
inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
if (cublas_tensor_core_handle_) {
cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
} else {
cublas_handle_->Call(std::forward<Callback>(callback));
}
}
/*! \brief Return cudnn handle in the device context. */ /*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const; cudnnHandle_t cudnn_handle() const;
...@@ -256,7 +267,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -256,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
template <typename Callback> template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) { void RecordEvent(cudaEvent_t ev, Callback callback) {
std::lock_guard<std::mutex> guard(mtx_);
callback(); callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
...@@ -275,8 +285,9 @@ class CUDADeviceContext : public DeviceContext { ...@@ -275,8 +285,9 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
std::unique_ptr<CudnnHolder> cudnn_holder_; std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_; cudaStream_t stream_;
cublasHandle_t cublas_handle_;
std::unique_ptr<cublasHandle_t> cublas_tensor_core_handle_; std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
int compute_capability_; int compute_capability_;
int runtime_version_; int runtime_version_;
...@@ -284,12 +295,10 @@ class CUDADeviceContext : public DeviceContext { ...@@ -284,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
int multi_process_; int multi_process_;
int max_threads_per_mp_; int max_threads_per_mp_;
mutable std::mutex mtx_;
// StreamCallbackManager is thread-safe // StreamCallbackManager is thread-safe
std::unique_ptr<StreamCallbackManager> callback_manager_; std::unique_ptr<StreamCallbackManager> callback_manager_;
mutable std::mutex cublas_mtx_; DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
}; };
template <> template <>
......
...@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) { ...@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, gpu_device); ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
ASSERT_NE(nullptr, device_context->stream());
delete device_context; delete device_context;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册