未验证 提交 50f149a4 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Fix cuDNN workspace size problem during inference. (#26021)

test=develop
上级 1f74b94d
......@@ -216,6 +216,12 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo;
VLOG(3) << "cuDNN forward algo " << algo;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
if (workspace_size_in_bytes > workspace_size_limit)
workspace_size_limit = workspace_size_in_bytes;
} else {
std::function<cudnnConvolutionFwdAlgo_t()> search_func =
[&]() -> cudnnConvolutionFwdAlgo_t {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册