diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 22cbf680c0670552fb014043c69fcadc56863529..92435d7c417bc35987207060c5c99f1f89f8570b 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -159,9 +159,8 @@ class CUDNNConvOpKernel : public framework::OpKernel { PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, "workspace_size to be allocated exceeds the limit"); - // Allocate on GPU memory - platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); - cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // Get cudnn workspace + cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes); // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; for (int i = 0; i < groups; i++) { @@ -171,8 +170,6 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_output_desc, output_data + i * group_offset_out)); } - // Release the cudnn workspace - paddle::memory::Free(gpu, cudnn_workspace); } }; @@ -315,10 +312,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); } // ------------------- cudnn conv workspace --------------------- - // Already on GPU - void* cudnn_workspace = nullptr; - platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); - cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + void* cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; if (input_grad) { @@ -347,8 +341,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { filter_grad_data + i * group_offset_filter)); } } - // Release the cudnn workspace - paddle::memory::Free(gpu, cudnn_workspace); } }; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 01fa9301d61045b8ae5d04dfec84e7a593b38059..3f8da69fc2f1a0ffeeb55328e1fd25134e57af5c 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -162,6 +162,7 @@ class CudnnHolder { paddle::memory::Free(place_, workspace_); } workspace_ = new_workspace; + workspace_len_ = required_len; } return workspace_ }