diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index 4c9727391759b0c1865e9fc51288458e7786c878..7ad49de4eed5e26cdc24a7444ead9a50abf54453 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -251,7 +251,7 @@ struct SearchAlgorithm { args.cdesc.desc(), args.odesc.desc(), kNUM_CUDNN_FWD_ALGS, &perf_count, perf_results.get())); algo = (perf_results.get())[best_algo_idx].algo; - workspace_size = GetWorkspaceSize(args, algo); + workspace_size = (perf_results.get())[best_algo_idx].memory; if (workspace_size > workspace_size_limit) { #if CUDNN_VERSION >= 8000 @@ -502,7 +502,8 @@ struct SearchAlgorithm { args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS, &perf_count, perf_results.get())); algo = (perf_results.get())[best_algo_idx].algo; - workspace_size = GetWorkspaceSize(args, algo); + workspace_size = (perf_results.get())[best_algo_idx].memory; + if (workspace_size > workspace_size_limit) { workspace_size = workspace_size_limit; #if CUDNN_VERSION >= 8000