diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc index 82fff68e7557b3f0b44e6faf2a50e5a0ecbba589..c24cb14a6160df71b1d847af0e19b016ec85342f 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc @@ -100,9 +100,8 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, cudnn_output_desc, algo, &workspace_size_in_bytes)); - // Allocate on GPU memory - platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); - cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // Get cudnn workspace + cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes); // ------------------- cudnn conv transpose forward --------------------- int input_offset = input->numel() / input->dims()[0] / groups; @@ -116,9 +115,6 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { algo, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_output_desc, output_data + output_offset * g)); } - - // Release the cudnn workspace - paddle::memory::Free(gpu, cudnn_workspace); } }; @@ -207,10 +203,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { } // ------------------- cudnn conv workspace --------------------- - // Already on GPU - void* cudnn_workspace = nullptr; - platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); - cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // Get cudnn workspace + void* cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- // FIXME(typhoonzero): template type T may not be the same as cudnn call. int input_offset = input->numel() / input->dims()[0] / groups; @@ -245,9 +239,6 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { filter_grad_data + filter_offset * g)); } } - - // Release the cudnn workspace - paddle::memory::Free(gpu, cudnn_workspace); } };