未验证 提交 79a32b9d 编写于 作者: W Wilber 提交者: GitHub

Support compiling with CUDA 11 + cuDNN 8. test=develop (#3788)

上级 a8905fbb
......@@ -127,6 +127,10 @@ static const char* CudnnGetErrorInfo(cudnnStatus_t status) {
return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
#endif
#if CUDNN_VERSION_MIN(8, 0, 0)
case CUDNN_STATUS_VERSION_MISMATCH:
return "CUDNN_STATUS_VERSION_MISMATCH";
#endif
}
return "Unknown cudnn status";
......
......@@ -161,15 +161,17 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
search_func);
} else {
CUDNN_CHECK(
cudnnGetConvolutionForwardAlgorithm(this->handle_,
int requestedAlgoCount = 1;
int returnedAlgoCount;
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(this->handle_,
this->input_desc_,
this->filter_desc_,
this->conv_desc_,
this->output_desc_,
this->preference_,
this->workspace_limit_bytes_,
&this->fwd_algo_));
requestedAlgoCount,
&returnedAlgoCount,
&this->algo_perf_));
this->fwd_algo_ = this->algo_perf_.algo;
}
CUDNN_CHECK(
cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
......
......@@ -81,6 +81,7 @@ class CudnnConv2DBase {
cudaStream_t stream_;
cudnnHandle_t handle_;
cudnnConvolutionFwdAlgo_t fwd_algo_;
cudnnConvolutionFwdAlgoPerf_t algo_perf_;
cudnnTensorDescriptor_t input_desc_;
cudnnTensorDescriptor_t output_desc_;
cudnnTensorDescriptor_t bias_desc_;
......@@ -98,8 +99,6 @@ class CudnnConv2DBase {
const bool use_tensor_core_ = true;
const size_t workspace_limit_bytes_ = 4 * 1024 * 1024;
const cudnnConvolutionFwdPreference_t preference_ =
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
// For int8
Tensor temp_tensor_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册