未验证 提交 79a32b9d 编写于 作者: W Wilber 提交者: GitHub

support compile with cuda11 + cudnn8. test=develop (#3788)

上级 a8905fbb
...@@ -127,6 +127,10 @@ static const char* CudnnGetErrorInfo(cudnnStatus_t status) { ...@@ -127,6 +127,10 @@ static const char* CudnnGetErrorInfo(cudnnStatus_t status) {
return "CUDNN_STATUS_RUNTIME_IN_PROGRESS"; return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
case CUDNN_STATUS_RUNTIME_FP_OVERFLOW: case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW"; return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
#endif
#if CUDNN_VERSION_MIN(8, 0, 0)
case CUDNN_STATUS_VERSION_MISMATCH:
return "CUDNN_STATUS_VERSION_MISMATCH";
#endif #endif
} }
return "Unknown cudnn status"; return "Unknown cudnn status";
......
...@@ -161,15 +161,17 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param, ...@@ -161,15 +161,17 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
search_func); search_func);
} else { } else {
CUDNN_CHECK( int requestedAlgoCount = 1;
cudnnGetConvolutionForwardAlgorithm(this->handle_, int returnedAlgoCount;
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(this->handle_,
this->input_desc_, this->input_desc_,
this->filter_desc_, this->filter_desc_,
this->conv_desc_, this->conv_desc_,
this->output_desc_, this->output_desc_,
this->preference_, requestedAlgoCount,
this->workspace_limit_bytes_, &returnedAlgoCount,
&this->fwd_algo_)); &this->algo_perf_));
this->fwd_algo_ = this->algo_perf_.algo;
} }
CUDNN_CHECK( CUDNN_CHECK(
cudnnGetConvolutionForwardWorkspaceSize(this->handle_, cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
......
...@@ -81,6 +81,7 @@ class CudnnConv2DBase { ...@@ -81,6 +81,7 @@ class CudnnConv2DBase {
cudaStream_t stream_; cudaStream_t stream_;
cudnnHandle_t handle_; cudnnHandle_t handle_;
cudnnConvolutionFwdAlgo_t fwd_algo_; cudnnConvolutionFwdAlgo_t fwd_algo_;
cudnnConvolutionFwdAlgoPerf_t algo_perf_;
cudnnTensorDescriptor_t input_desc_; cudnnTensorDescriptor_t input_desc_;
cudnnTensorDescriptor_t output_desc_; cudnnTensorDescriptor_t output_desc_;
cudnnTensorDescriptor_t bias_desc_; cudnnTensorDescriptor_t bias_desc_;
...@@ -98,8 +99,6 @@ class CudnnConv2DBase { ...@@ -98,8 +99,6 @@ class CudnnConv2DBase {
const bool use_tensor_core_ = true; const bool use_tensor_core_ = true;
const size_t workspace_limit_bytes_ = 4 * 1024 * 1024; const size_t workspace_limit_bytes_ = 4 * 1024 * 1024;
const cudnnConvolutionFwdPreference_t preference_ =
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
// For int8 // For int8
Tensor temp_tensor_; Tensor temp_tensor_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册