Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
79a32b9d
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
79a32b9d
编写于
6月 15, 2020
作者:
W
Wilber
提交者:
GitHub
6月 15, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support compile with cuda11 + cudnn8. test=develop (#3788)
上级
a8905fbb
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
16 addition
and
11 deletion
+16
-11
lite/backends/cuda/cuda_utils.h
lite/backends/cuda/cuda_utils.h
+4
-0
lite/backends/cuda/math/cudnn_conv.cc
lite/backends/cuda/math/cudnn_conv.cc
+11
-9
lite/backends/cuda/math/cudnn_conv.h
lite/backends/cuda/math/cudnn_conv.h
+1
-2
未找到文件。
lite/backends/cuda/cuda_utils.h
浏览文件 @
79a32b9d
...
@@ -127,6 +127,10 @@ static const char* CudnnGetErrorInfo(cudnnStatus_t status) {
...
@@ -127,6 +127,10 @@ static const char* CudnnGetErrorInfo(cudnnStatus_t status) {
return
"CUDNN_STATUS_RUNTIME_IN_PROGRESS"
;
return
"CUDNN_STATUS_RUNTIME_IN_PROGRESS"
;
case
CUDNN_STATUS_RUNTIME_FP_OVERFLOW
:
case
CUDNN_STATUS_RUNTIME_FP_OVERFLOW
:
return
"CUDNN_STATUS_RUNTIME_FP_OVERFLOW"
;
return
"CUDNN_STATUS_RUNTIME_FP_OVERFLOW"
;
#endif
#if CUDNN_VERSION_MIN(8, 0, 0)
case
CUDNN_STATUS_VERSION_MISMATCH
:
return
"CUDNN_STATUS_VERSION_MISMATCH"
;
#endif
#endif
}
}
return
"Unknown cudnn status"
;
return
"Unknown cudnn status"
;
...
...
lite/backends/cuda/math/cudnn_conv.cc
浏览文件 @
79a32b9d
...
@@ -161,15 +161,17 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
...
@@ -161,15 +161,17 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
search_func
);
search_func
);
}
else
{
}
else
{
CUDNN_CHECK
(
int
requestedAlgoCount
=
1
;
cudnnGetConvolutionForwardAlgorithm
(
this
->
handle_
,
int
returnedAlgoCount
;
this
->
input_desc_
,
CUDNN_CHECK
(
cudnnGetConvolutionForwardAlgorithm_v7
(
this
->
handle_
,
this
->
filter_desc_
,
this
->
input_desc_
,
this
->
conv_desc_
,
this
->
filter_desc_
,
this
->
output_desc_
,
this
->
conv_desc_
,
this
->
preference_
,
this
->
output_desc_
,
this
->
workspace_limit_bytes_
,
requestedAlgoCount
,
&
this
->
fwd_algo_
));
&
returnedAlgoCount
,
&
this
->
algo_perf_
));
this
->
fwd_algo_
=
this
->
algo_perf_
.
algo
;
}
}
CUDNN_CHECK
(
CUDNN_CHECK
(
cudnnGetConvolutionForwardWorkspaceSize
(
this
->
handle_
,
cudnnGetConvolutionForwardWorkspaceSize
(
this
->
handle_
,
...
...
lite/backends/cuda/math/cudnn_conv.h
浏览文件 @
79a32b9d
...
@@ -81,6 +81,7 @@ class CudnnConv2DBase {
...
@@ -81,6 +81,7 @@ class CudnnConv2DBase {
cudaStream_t
stream_
;
cudaStream_t
stream_
;
cudnnHandle_t
handle_
;
cudnnHandle_t
handle_
;
cudnnConvolutionFwdAlgo_t
fwd_algo_
;
cudnnConvolutionFwdAlgo_t
fwd_algo_
;
cudnnConvolutionFwdAlgoPerf_t
algo_perf_
;
cudnnTensorDescriptor_t
input_desc_
;
cudnnTensorDescriptor_t
input_desc_
;
cudnnTensorDescriptor_t
output_desc_
;
cudnnTensorDescriptor_t
output_desc_
;
cudnnTensorDescriptor_t
bias_desc_
;
cudnnTensorDescriptor_t
bias_desc_
;
...
@@ -98,8 +99,6 @@ class CudnnConv2DBase {
...
@@ -98,8 +99,6 @@ class CudnnConv2DBase {
const
bool
use_tensor_core_
=
true
;
const
bool
use_tensor_core_
=
true
;
const
size_t
workspace_limit_bytes_
=
4
*
1024
*
1024
;
const
size_t
workspace_limit_bytes_
=
4
*
1024
*
1024
;
const
cudnnConvolutionFwdPreference_t
preference_
=
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST
;
// For int8
// For int8
Tensor
temp_tensor_
;
Tensor
temp_tensor_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录