提交 d0a8a1e9 编写于 作者: S sneaxiy

remove_op_handle_lock

test=develop
上级 6f06e6cd
......@@ -68,9 +68,11 @@ struct CUBlas<float> {
#if CUDA_VERSION >= 8000
VLOG(5) << "use_tensor_op_math: "
<< (dev_ctx->tensor_core_available() ? "True" : "False");
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc));
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc));
});
#else
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif
......@@ -171,10 +173,11 @@ struct CUBlas<platform::float16> {
<< (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType,
algo));
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc, computeType, algo));
});
#else
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif
......@@ -204,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CUDA_R_32F, N);
} else {
#endif // CUDA_VERSION >= 8000
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
&alpha, B, ldb, A, lda, &beta, C, N);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, N);
});
#if CUDA_VERSION >= 8000
}
......@@ -247,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &h_alpha, h_B, ldb, h_A, lda,
&h_beta, h_C, N);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
&h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C,
N);
});
#endif // CUDA_VERSION >= 8000
}
......@@ -273,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
} else {
#endif // CUDA_VERSION >= 8000
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
&alpha, B, ldb, A, lda, &beta, C, ldc);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, ldc);
});
#if CUDA_VERSION >= 8000
}
......@@ -292,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &alpha, B, ldb, A, lda, &beta, C,
ldc);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, A, lda, &beta, C, ldc);
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
T *y) const {
CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
});
}
template <>
......@@ -311,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T beta, T *C) const {
cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1,
&beta, C, 1);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
});
}
template <>
......@@ -342,16 +355,20 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
context_.possible_cublas_tensor_core_handle(), cuTransB, cuTransA, N, M,
K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA,
&beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
strideC, batchCount, CUDA_R_32F, algo));
});
} else {
#endif // CUDA_VERSION >= 9010
CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &alpha, B, ldb, strideB, A, lda,
strideA, &beta, C, ldc, strideC, batchCount);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, strideB, A, lda, strideA, &beta, C,
ldc, strideC, batchCount);
});
#if CUDA_VERSION >= 9010
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 };
#endif
namespace paddle {
namespace platform {
class CublasHandleHolder {
public:
CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
PADDLE_ENFORCE(dynload::cublasCreate(&handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream));
#if CUDA_VERSION >= 9000
if (math_type == CUBLAS_TENSOR_OP_MATH) {
PADDLE_ENFORCE(
dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
}
~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); }
template <typename Callback>
inline void Call(Callback &&callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
private:
DISABLE_COPY_AND_ASSIGN(CublasHandleHolder);
cublasHandle_t handle_;
mutable std::mutex mtx_;
};
} // namespace platform
} // namespace paddle
......@@ -245,17 +245,12 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(new cublasHandle_t());
PADDLE_ENFORCE(dynload::cublasCreate(cublas_tensor_core_handle_.get()));
PADDLE_ENFORCE(
dynload::cublasSetStream(*cublas_tensor_core_handle_, stream_));
PADDLE_ENFORCE(dynload::cublasSetMathMode(*cublas_tensor_core_handle_,
CUBLAS_TENSOR_OP_MATH));
cublas_tensor_core_handle_.reset(
new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
#endif
}
......@@ -318,11 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
WaitStreamCallback();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
if (cublas_tensor_core_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(*cublas_tensor_core_handle_));
cublas_tensor_core_handle_.reset();
}
cublas_handle_.reset();
cublas_tensor_core_handle_.reset();
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
......@@ -351,15 +343,6 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return cublas_handle_;
}
cublasHandle_t CUDADeviceContext::possible_cublas_tensor_core_handle() const {
return cublas_tensor_core_handle_ ? *cublas_tensor_core_handle_
: cublas_handle_;
}
bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr;
}
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
......@@ -229,15 +230,25 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */
Eigen::GpuDevice* eigen_device() const;
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;
/*! \brief Call cublas function safely. */
template <typename Callback>
inline void CublasCall(Callback&& callback) const {
cublas_handle_->Call(std::forward<Callback>(callback));
}
/*! \brief Check whether tensor core is supported */
bool tensor_core_available() const;
/*! \brief Return cublas handle supporting Tensor Core. If Tensor Core is
* not supported, return the same handle as cublas_handle(). */
cublasHandle_t possible_cublas_tensor_core_handle() const;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template <typename Callback>
inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
if (cublas_tensor_core_handle_) {
cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
} else {
cublas_handle_->Call(std::forward<Callback>(callback));
}
}
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
......@@ -256,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) {
std::lock_guard<std::mutex> guard(mtx_);
callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
}
......@@ -275,8 +285,9 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_;
cublasHandle_t cublas_handle_;
std::unique_ptr<cublasHandle_t> cublas_tensor_core_handle_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
int compute_capability_;
int runtime_version_;
......@@ -284,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
int multi_process_;
int max_threads_per_mp_;
mutable std::mutex mtx_;
// StreamCallbackManager is thread-safe
std::unique_ptr<StreamCallbackManager> callback_manager_;
mutable std::mutex cublas_mtx_;
DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
};
template <>
......
......@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
ASSERT_NE(nullptr, device_context->stream());
delete device_context;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册