diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index e8f371cb4877f343d108e8528345be03cd9b354b..b22f28fbbe3ce8ce178a3d9c17a048817cb750e7 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -216,6 +216,12 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { perf_results.get())); algo = (perf_results.get())[best_algo_idx].algo; VLOG(3) << "cuDNN forward algo " << algo; + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, algo, &workspace_size_in_bytes)); + if (workspace_size_in_bytes > workspace_size_limit) + workspace_size_limit = workspace_size_in_bytes; } else { std::function search_func = [&]() -> cudnnConvolutionFwdAlgo_t {