From d5f74b73082db157beab3d62f1783d902397bd26 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 30 Aug 2018 17:53:28 +0800 Subject: [PATCH] use CudnnHolder in conv_transpose_cudnn_op --- .../operators/conv_transpose_cudnn_op.cu.cc | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc index 82fff68e75..c24cb14a61 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc @@ -100,9 +100,8 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, cudnn_output_desc, algo, &workspace_size_in_bytes)); - // Allocate on GPU memory - platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); - cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // Get cudnn workspace + cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes); // ------------------- cudnn conv transpose forward --------------------- int input_offset = input->numel() / input->dims()[0] / groups; @@ -116,9 +115,6 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_output_desc, output_data + output_offset * g)); } - - // Release the cudnn workspace - paddle::memory::Free(gpu, cudnn_workspace); } }; @@ -207,10 +203,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { } // ------------------- cudnn conv workspace --------------------- - // Already on GPU - void* cudnn_workspace = nullptr; - platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); - cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // Get cudnn workspace + void* cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- // FIXME(typhoonzero): template type T may not be the same as cudnn call. int input_offset = input->numel() / input->dims()[0] / groups; @@ -245,9 +239,6 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { filter_grad_data + filter_offset * g)); } } - - // Release the cudnn workspace - paddle::memory::Free(gpu, cudnn_workspace); } }; -- GitLab