Commit d5f74b73 authored by fengjiayi

use CudnnHolder in conv_transpose_cudnn_op

Parent 15cc9128
@@ -100,9 +100,8 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
         handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
         cudnn_output_desc, algo, &workspace_size_in_bytes));
-    // Allocate on GPU memory
-    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
+    // Get cudnn workspace
+    cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
     // ------------------- cudnn conv transpose forward ---------------------
     int input_offset = input->numel() / input->dims()[0] / groups;
@@ -116,9 +115,6 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
           algo, cudnn_workspace, workspace_size_in_bytes, &beta,
           cudnn_output_desc, output_data + output_offset * g));
     }
-    // Release the cudnn workspace
-    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
@@ -207,10 +203,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
     }
     // ------------------- cudnn conv workspace ---------------------
-    // Already on GPU
-    void* cudnn_workspace = nullptr;
-    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
+    // Get cudnn workspace
+    void* cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
     // ------------------- cudnn conv backward data ---------------------
     // FIXME(typhoonzero): template type T may not be the same as cudnn call.
     int input_offset = input->numel() / input->dims()[0] / groups;
@@ -245,9 +239,6 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
           filter_grad_data + filter_offset * g));
       }
     }
-    // Release the cudnn workspace
-    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
...
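The gist of the change: instead of allocating a fresh GPU buffer with paddle::memory::Alloc and freeing it with paddle::memory::Free on every kernel invocation, the workspace is now requested from the device context, which keeps a reusable CudnnHolder buffer. As a rough illustration only, a minimal sketch of that pattern might look like the following; the class and member names here are assumptions for illustration, not Paddle's actual CudnnHolder API.

// Minimal sketch of the CudnnHolder idea: a workspace buffer owned by the
// device context, grown on demand and reused across cuDNN calls, so kernels
// no longer allocate and free GPU memory on every invocation.
// WorkspaceHolder and cudnn_workspace here are illustrative names only.
#include <cstddef>
#include <cuda_runtime.h>

class WorkspaceHolder {
 public:
  ~WorkspaceHolder() {
    if (ptr_ != nullptr) cudaFree(ptr_);
  }

  // Returns a buffer of at least `required_len` bytes, reallocating only
  // when the cached buffer is too small.
  void* cudnn_workspace(size_t required_len) {
    if (required_len > len_) {
      if (ptr_ != nullptr) cudaFree(ptr_);
      cudaMalloc(&ptr_, required_len);
      len_ = required_len;
    }
    return ptr_;
  }

 private:
  void* ptr_ = nullptr;   // cached device buffer, reused across calls
  size_t len_ = 0;        // current capacity of the cached buffer in bytes
};

Reusing one cached buffer avoids allocator traffic on the hot path, and the buffer only grows when a call needs more workspace than any previous one. A production implementation would also have to consider stream ordering and concurrent callers, which this sketch ignores.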