Unverified commit 6a4790c2, authored by Xin Pan, committed by GitHub

Merge pull request #15203 from sneaxiy/revert-15139-remove_op_handle_lock

Revert "Remove op handle lock"
@@ -62,19 +62,27 @@ struct CUBlas<float> {
                       cudaDataType_t Atype, int lda, const void *B,
                       cudaDataType_t Btype, int ldb, const float *beta, void *C,
                       cudaDataType_t Ctype, int ldc) {
     // Because the gcc 4.8 doesn't expand template parameter pack that
     // appears in a lambda-expression, I can not use template parameter pack
     // here.
+    auto cublas_call = [&]() {
 #if CUDA_VERSION >= 8000
-    VLOG(5) << "use_tensor_op_math: "
-            << (dev_ctx->tensor_core_available() ? "True" : "False");
-    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-      PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
-          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
-          beta, C, Ctype, ldc));
-    });
+      VLOG(5) << "use_tensor_op_math: "
+              << (platform::TensorCoreAvailable() ? "True" : "False");
+      PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
+          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
+          lda, B, Btype, ldb, beta, C, Ctype, ldc));
 #else
-    PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
+      PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
+#endif
+    };
+
+#if CUDA_VERSION >= 9000
+    // NOTES: To use Tensor Core, we should change the cublas config,
+    // but the cublas may be hold by multi-thread.
+    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+#else
+    cublas_call();
 #endif
   }
 };
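The hunk above restores the pattern used throughout this file: the cuBLAS invocation is wrapped in a cublas_call lambda and routed through CublasCall, which serializes access to the shared handle and temporarily switches its math mode. Here is a minimal, self-contained sketch of that idea, not PaddlePaddle code: Context, RunSgemm, and the sizes are illustrative, CUDA >= 9.0 is assumed, and error checking of the cuBLAS return codes is omitted for brevity.

#include <cublas_v2.h>
#include <mutex>

// Illustrative only (assumes CUDA >= 9.0): a shared cuBLAS handle guarded by
// a mutex, with the math mode switched just for the duration of one call.
struct Context {  // hypothetical stand-in for the device context
  cublasHandle_t handle;
  std::mutex mtx;

  template <typename Callback>
  void CublasCall(Callback callback, cublasMath_t new_math) {
    std::lock_guard<std::mutex> guard(mtx);  // serialize users of the handle
    cublasMath_t old_math;
    cublasGetMathMode(handle, &old_math);    // remember the current mode
    cublasSetMathMode(handle, new_math);     // e.g. CUBLAS_TENSOR_OP_MATH
    callback();                              // the actual cuBLAS call
    cublasSetMathMode(handle, old_math);     // restore the previous mode
  }
};

void RunSgemm(Context *ctx, int m, int n, int k, const float *A,
              const float *B, float *C) {
  const float alpha = 1.0f, beta = 0.0f;
  auto cublas_call = [&]() {
    // Column-major SGEMM; A is m x k, B is k x n, C is m x n, all on the GPU.
    cublasSgemm(ctx->handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, A, m,
                B, k, &beta, C, m);
  };
  ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
}

The mutex matters because a cublasHandle_t is not safe to reconfigure concurrently, so the math-mode switch and the call it protects must happen under a single lock.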
@@ -162,10 +170,11 @@ struct CUBlas<platform::float16> {
                       cudaDataType_t Btype, int ldb, const void *beta, void *C,
                       cudaDataType_t Ctype, int ldc,
                       cudaDataType_t computeType) {
+    auto cublas_call = [&]() {
 #if CUDA_VERSION >= 8000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
 #if CUDA_VERSION >= 9000
-    bool use_tensor_op_math = dev_ctx->tensor_core_available();
+    bool use_tensor_op_math = platform::TensorCoreAvailable();
     if (use_tensor_op_math) {
       algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
     }

@@ -173,13 +182,20 @@ struct CUBlas<platform::float16> {
     VLOG(5) << "use_tensor_op_math: "
             << (use_tensor_op_math ? "True" : "False");
 #endif  // CUDA_VERSION >= 9000
-    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-      PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
-          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
-          beta, C, Ctype, ldc, computeType, algo));
-    });
+    PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+        dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
+        lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
 #else
     PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
+#endif
+    };
+
+#if CUDA_VERSION >= 9000
+    // NOTES: To use Tensor Core, we should change the cublas config,
+    // but the cublas may be hold by multi-thread.
+    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+#else
+    cublas_call();
 #endif
   }
 };
@@ -207,10 +223,9 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                     CUDA_R_32F, N);
   } else {
 #endif  // CUDA_VERSION >= 8000
-    context_.CublasCall([&](cublasHandle_t handle) {
-      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
-                      lda, &beta, C, N);
-    });
+    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
+                    &alpha, B, ldb, A, lda, &beta, C, N);
 #if CUDA_VERSION >= 8000
   }
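A note on the argument order visible in this and the following GEMM hunks: cuBLAS expects column-major matrices, while the caller stores them row-major, so row-major C = A x B is computed as C^T = B^T x A^T by passing B before A and swapping the transpose flags and the M/N extents. A small standalone sketch of the same trick (illustrative names, no-transpose case, error checking omitted):

#include <cublas_v2.h>

// Row-major C (M x N) = A (M x K) * B (K x N) using column-major cuBLAS:
// compute C^T = B^T * A^T by passing (B, A) and the extents (N, M, K).
void RowMajorSgemm(cublasHandle_t handle, int M, int N, int K,
                   const float *A, const float *B, float *C) {
  const float alpha = 1.0f, beta = 0.0f;
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
              B, N,   // B is K x N row-major, i.e. N x K column-major, ld = N
              A, K,   // A is M x K row-major, i.e. K x M column-major, ld = K
              &beta,
              C, N);  // C is M x N row-major, i.e. N x M column-major, ld = N
}

This is why the calls in the diff pass cuTransB, cuTransA, N, M, K and B, ldb, A, lda rather than the textbook order.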
@@ -251,12 +266,9 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
       CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
-  context_.CublasCall([&](cublasHandle_t handle) {
-    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
-                                    &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C,
-                                    N);
-  });
+  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
+                                  N, M, K, &h_alpha, h_B, ldb, h_A, lda,
+                                  &h_beta, h_C, N);
 #endif  // CUDA_VERSION >= 8000
 }
@@ -280,10 +292,8 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
   } else {
 #endif  // CUDA_VERSION >= 8000
-    context_.CublasCall([&](cublasHandle_t handle) {
-      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
-                      lda, &beta, C, ldc);
-    });
+    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
+                    &alpha, B, ldb, A, lda, &beta, C, ldc);
 #if CUDA_VERSION >= 8000
   }
@@ -301,19 +311,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
   cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
-  context_.CublasCall([&](cublasHandle_t handle) {
-    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
-                                    B, ldb, A, lda, &beta, C, ldc);
-  });
+  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
+                                  N, M, K, &alpha, B, ldb, A, lda, &beta, C,
+                                  ldc);
 }

 template <>
 template <typename T>
 void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
                                              T *y) const {
-  context_.CublasCall([&](cublasHandle_t handle) {
-    CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
-  });
+  CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1);
 }

 template <>
@@ -323,9 +330,8 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
                                              T beta, T *C) const {
   cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
-  context_.CublasCall([&](cublasHandle_t handle) {
-    CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
-  });
+  CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1,
+                  &beta, C, 1);
 }

 template <>
@@ -347,28 +353,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
 #if CUDA_VERSION >= 9010
   if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
-    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
-    bool use_tensor_op_math = context_.tensor_core_available();
-    if (use_tensor_op_math) {
-      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-    }
-    VLOG(5) << "use_tensor_op_math: "
-            << (use_tensor_op_math ? "True" : "False");
-    context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-      PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
-          handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
-          strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
-          strideC, batchCount, CUDA_R_32F, algo));
-    });
+    auto cublas_call = [&]() {
+      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+      bool use_tensor_op_math = platform::TensorCoreAvailable();
+      if (use_tensor_op_math) {
+        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+      }
+      VLOG(5) << "use_tensor_op_math: "
+              << (use_tensor_op_math ? "True" : "False");
+      PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
+          context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
+          CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
+          CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
+    };
+    auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+    dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
   } else {
 #endif  // CUDA_VERSION >= 9010
-    context_.CublasCall([&](cublasHandle_t handle) {
-      CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
-                                    B, ldb, strideB, A, lda, strideA, &beta, C,
-                                    ldc, strideC, batchCount);
-    });
+    CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
+                                  N, M, K, &alpha, B, ldb, strideB, A, lda,
+                                  strideA, &beta, C, ldc, strideC, batchCount);
 #if CUDA_VERSION >= 9010
   }
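For readers unfamiliar with the strided-batched call that this hunk reverts to routing through CublasCall: it performs batchCount independent GEMMs over matrices laid out back to back, where matrix i of A starts at A + i * strideA (likewise for B and C). A hedged, standalone sketch using the plain FP32 entry point; names are illustrative, error checking is omitted, and the PR itself goes through cublasGemmStridedBatchedEx so that it can select a Tensor Core algorithm:

#include <cublas_v2.h>

// batchCount row-major GEMMs C_i (M x N) = A_i (M x K) * B_i (K x N), with the
// matrices of each operand stored contiguously one after another.
void BatchedRowMajorSgemm(cublasHandle_t handle, int M, int N, int K,
                          const float *A, const float *B, float *C,
                          int batchCount) {
  const float alpha = 1.0f, beta = 0.0f;
  const long long strideA = static_cast<long long>(M) * K;
  const long long strideB = static_cast<long long>(K) * N;
  const long long strideC = static_cast<long long>(M) * N;
  // Same row-major trick as above: pass (B, A) and the extents (N, M, K).
  cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
                            B, N, strideB, A, K, strideA, &beta,
                            C, N, strideC, batchCount);
}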
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <mutex>  // NOLINT

#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"

#if CUDA_VERSION < 9000
enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 };
#endif

namespace paddle {
namespace platform {

class CublasHandleHolder {
 public:
  CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
    PADDLE_ENFORCE(dynload::cublasCreate(&handle_));
    PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream));
#if CUDA_VERSION >= 9000
    if (math_type == CUBLAS_TENSOR_OP_MATH) {
      PADDLE_ENFORCE(
          dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH));
    }
#endif
  }

  ~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); }

  template <typename Callback>
  inline void Call(Callback &&callback) const {
    std::lock_guard<std::mutex> guard(mtx_);
    callback(handle_);
  }

 private:
  DISABLE_COPY_AND_ASSIGN(CublasHandleHolder);

  cublasHandle_t handle_;
  mutable std::mutex mtx_;
};

}  // namespace platform
}  // namespace paddle
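The header above belongs to the per-handle-lock design that this revert removes: each CublasHandleHolder owns one cublasHandle_t plus one mutex, and Call runs a callback while holding that mutex. A hedged usage sketch, not code from the PR; RunAxpy and its arguments are illustrative, and in the removed design the holders live inside CUDADeviceContext rather than being created per call:

#include <cublas_v2.h>
#include "paddle/fluid/platform/cuda_helper.h"  // the header shown above

void RunAxpy(cudaStream_t stream, int n, const float *x, float *y) {
  // One holder per math mode; here only the default-math handle is used.
  paddle::platform::CublasHandleHolder holder(stream, CUBLAS_DEFAULT_MATH);
  const float alpha = 1.0f;
  holder.Call([&](cublasHandle_t handle) {
    // The holder's mutex is held for the duration of this lambda.
    cublasSaxpy(handle, n, &alpha, x, 1, y, 1);
  });
}

The design point of the removed approach was that the default-math handle and the Tensor Core handle each carry their own lock, so callers of one do not block callers of the other.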
@@ -245,15 +245,8 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
   eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
-  cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
-
-  if (TensorCoreAvailable()) {
-#if CUDA_VERSION >= 9000
-    cublas_tensor_core_handle_.reset(
-        new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
-#endif
-  }
-
+  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
+  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
   if (dynload::HasCUDNN()) {
     cudnn_holder_.reset(new CudnnHolder(&stream_, place));
   }
@@ -313,8 +306,7 @@ CUDADeviceContext::~CUDADeviceContext() {
   SetDeviceId(place_.device);
   Wait();
   WaitStreamCallback();
-  cublas_handle_.reset();
-  cublas_tensor_core_handle_.reset();
+  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));
@@ -343,8 +335,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }

-bool CUDADeviceContext::tensor_core_available() const {
-  return cublas_tensor_core_handle_ != nullptr;
+cublasHandle_t CUDADeviceContext::cublas_handle() const {
+  return cublas_handle_;
 }

 cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
@@ -20,7 +20,6 @@ limitations under the License. */
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/temporary_allocator.h"
 #ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/cuda_helper.h"
 #include "paddle/fluid/platform/dynload/cublas.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/gpu_info.h"
@@ -210,6 +209,39 @@ class CudnnWorkspaceHandle {
   std::unique_ptr<std::lock_guard<std::mutex>> guard_;
 };

+#if CUDA_VERSION >= 9000
+class ScopedCublasMathMode {
+ public:
+  ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
+      : handle_(handle) {
+    need_reset = false;
+    PADDLE_ENFORCE(
+        platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
+        "Failed to get old cublas math mode");
+    if (old_math_mode_ != new_math_mode) {
+      PADDLE_ENFORCE(
+          platform::dynload::cublasSetMathMode(handle_, new_math_mode),
+          "Failed to set old cublas math mode");
+      need_reset = true;
+    }
+  }
+
+  ~ScopedCublasMathMode() {
+    if (need_reset) {
+      PADDLE_ENFORCE(
+          platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
+          "Failed to set old cublas math mode");
+    }
+  }
+
+ private:
+  cublasHandle_t handle_;
+  cublasMath_t old_math_mode_;
+  bool need_reset;
+};
+#endif
+
 class CUDADeviceContext : public DeviceContext {
  public:
  explicit CUDADeviceContext(CUDAPlace place);
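The ScopedCublasMathMode guard restored above is the RAII half of the locking scheme: it flips the handle to the requested math mode on construction and puts the previous mode back in its destructor, so the mode cannot leak out of the guarded region even on an early return. A hedged usage sketch, assuming CUDA >= 9.0 and that the surrounding header is the one included below; the function and its arguments are illustrative, and error checking is omitted:

#include <cublas_v2.h>
#include "paddle/fluid/platform/device_context.h"  // assumed header for the guard

void GemmWithTensorOpMath(cublasHandle_t handle, int m, int n, int k,
                          const float *A, const float *B, float *C) {
  const float alpha = 1.0f, beta = 0.0f;
  // While `guard` is alive the handle is allowed to use Tensor Core kernels;
  // whether cuBLAS actually picks one is still up to the library.
  paddle::platform::ScopedCublasMathMode guard(handle, CUBLAS_TENSOR_OP_MATH);
  cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                A, CUDA_R_32F, m, B, CUDA_R_32F, k, &beta, C, CUDA_R_32F, m);
}  // guard's destructor restores the previous math mode here

In the restored design, callers do not use the guard directly; CUDADeviceContext::CublasCall combines it with the cublas_mtx_ lock, as shown in a later hunk.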
@@ -230,25 +262,8 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return eigen device in the device context. */
   Eigen::GpuDevice* eigen_device() const;

-  /*! \brief Call cublas function safely. */
-  template <typename Callback>
-  inline void CublasCall(Callback&& callback) const {
-    cublas_handle_->Call(std::forward<Callback>(callback));
-  }
-
-  /*! \brief Check whether tensor core is supported */
-  bool tensor_core_available() const;
-
-  /*! \brief Call cublas function with Tensor Core safely. If
-      Tensor Core is not available, use DEFAULT_MATH instead. */
-  template <typename Callback>
-  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
-    if (cublas_tensor_core_handle_) {
-      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
-    } else {
-      cublas_handle_->Call(std::forward<Callback>(callback));
-    }
-  }
+  /*! \brief Return cublas handle in the device context. */
+  cublasHandle_t cublas_handle() const;

   /*! \brief Return cudnn handle in the device context. */
   cudnnHandle_t cudnn_handle() const;
@@ -267,6 +282,7 @@ class CUDADeviceContext : public DeviceContext {
   template <typename Callback>
   void RecordEvent(cudaEvent_t ev, Callback callback) {
+    std::lock_guard<std::mutex> guard(mtx_);
     callback();
     PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
   }
@@ -278,6 +294,18 @@ class CUDADeviceContext : public DeviceContext {
   void WaitStreamCallback() const { callback_manager_->Wait(); }

+#if CUDA_VERSION >= 9000
+  /*! \brief CublasCall may need to change cublas's config,
+   *  but the cublas may be hold by multi-thread, so we should
+   *  add lock here. */
+  template <typename Callback>
+  void CublasCall(Callback callback, cublasMath_t new_math) {
+    std::lock_guard<std::mutex> guard(cublas_mtx_);
+    ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
+    callback();
+  }
+#endif
+
  private:
   CUDAPlace place_;
@@ -285,9 +313,7 @@ class CUDADeviceContext : public DeviceContext {
   std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
   std::unique_ptr<CudnnHolder> cudnn_holder_;
   cudaStream_t stream_;
-
-  std::unique_ptr<CublasHandleHolder> cublas_handle_;
-  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
+  cublasHandle_t cublas_handle_;

   int compute_capability_;
   int runtime_version_;
@@ -295,10 +321,12 @@ class CUDADeviceContext : public DeviceContext {
   int multi_process_;
   int max_threads_per_mp_;

+  mutable std::mutex mtx_;
+
   // StreamCallbackManager is thread-safe
   std::unique_ptr<StreamCallbackManager> callback_manager_;

-  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
+  mutable std::mutex cublas_mtx_;
 };

 template <>
@@ -43,6 +43,9 @@ TEST(Device, CUDADeviceContext) {
     ASSERT_NE(nullptr, gpu_device);
     cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
     ASSERT_NE(nullptr, cudnn_handle);
+    cublasHandle_t cublas_handle = device_context->cublas_handle();
+    ASSERT_NE(nullptr, cublas_handle);
+    ASSERT_NE(nullptr, device_context->stream());
     delete device_context;
   }
 }