未验证 提交 4c61e141 编写于 作者: L limingshu 提交者: GitHub

GetWorkspaceSize trigger modfication in heuristic cudnn conv (#39184)

* first commit

* add more changes
上级 57b2033b
...@@ -251,7 +251,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -251,7 +251,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
args.cdesc.desc(), args.odesc.desc(), kNUM_CUDNN_FWD_ALGS, args.cdesc.desc(), args.odesc.desc(), kNUM_CUDNN_FWD_ALGS,
&perf_count, perf_results.get())); &perf_count, perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo; algo = (perf_results.get())[best_algo_idx].algo;
workspace_size = GetWorkspaceSize(args, algo); workspace_size = (perf_results.get())[best_algo_idx].memory;
if (workspace_size > workspace_size_limit) { if (workspace_size > workspace_size_limit) {
#if CUDNN_VERSION >= 8000 #if CUDNN_VERSION >= 8000
...@@ -502,7 +502,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -502,7 +502,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS, args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS,
&perf_count, perf_results.get())); &perf_count, perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo; algo = (perf_results.get())[best_algo_idx].algo;
workspace_size = GetWorkspaceSize(args, algo); workspace_size = (perf_results.get())[best_algo_idx].memory;
if (workspace_size > workspace_size_limit) { if (workspace_size > workspace_size_limit) {
workspace_size = workspace_size_limit; workspace_size = workspace_size_limit;
#if CUDNN_VERSION >= 8000 #if CUDNN_VERSION >= 8000
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册