Commit a349bee6 authored by zchen0211

deconv2d cudnn

Parent e80489a4
@@ -79,13 +79,13 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
     // ------------------- cudnn conv workspace ---------------------
     void* cudnn_workspace = nullptr;
     size_t workspace_size_in_bytes;  // final workspace to allocate.
-    size_t tmp_size;
     size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
     if (user_workspace_size > 0) {
       workspace_size_limit = user_workspace_size * 1024 * 1024;
     }
     // ------------------- cudnn conv algorithm ---------------------
-    cudnnConvolutionBwdAlgo_t algo;
+    // cudnnConvolutionBwdAlgo_t algo;
+    cudnnConvolutionBwdDataAlgo_t algo;
     auto handle = ctx.cuda_device_context().cudnn_handle();
     // Get the algorithm
     PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
@@ -99,8 +99,8 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(
         platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
             handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
-            cudnn_output_desc, algo, &tmp_size));
-    workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
+            cudnn_output_desc, algo, &workspace_size_in_bytes));
+    // workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
     // Allocate on GPU memory
     platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
......
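For context, a minimal, self-contained sketch of the cuDNN backward-data query pattern this kernel relies on: pick a `cudnnConvolutionBwdDataAlgo_t` under a workspace limit, then write that algorithm's exact workspace size directly into `workspace_size_in_bytes` (the patch drops the intermediate `tmp_size`/`std::max` step). This is not PaddlePaddle code; it targets the pre-cuDNN-8 API in use at the time of this commit (`cudnnGetConvolutionBackwardDataAlgorithm` was later removed), and the tensor shapes, the `CHECK_CUDNN` macro, and the 4 MiB limit are illustrative assumptions.

```cpp
// Sketch only: query a backward-data (deconv) algorithm and its workspace
// size with cuDNN 6/7. Shapes and limits are arbitrary examples.
#include <cudnn.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUDNN(call)                                             \
  do {                                                                \
    cudnnStatus_t s = (call);                                         \
    if (s != CUDNN_STATUS_SUCCESS) {                                  \
      fprintf(stderr, "cuDNN error %s at line %d\n",                  \
              cudnnGetErrorString(s), __LINE__);                      \
      exit(1);                                                        \
    }                                                                 \
  } while (0)

int main() {
  cudnnHandle_t handle;
  CHECK_CUDNN(cudnnCreate(&handle));

  // Transposed conv seen as cuDNN "backward data":
  // dy (op input) 1x16x8x8, filter 16x3x5x5, stride 2, pad 2
  // -> dx (op output) 1x3x15x15, since (8-1)*2 - 2*2 + 5 = 15.
  cudnnTensorDescriptor_t dy_desc, dx_desc;
  cudnnFilterDescriptor_t w_desc;
  cudnnConvolutionDescriptor_t conv_desc;
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&dy_desc));
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&dx_desc));
  CHECK_CUDNN(cudnnCreateFilterDescriptor(&w_desc));
  CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&conv_desc));

  CHECK_CUDNN(cudnnSetTensor4dDescriptor(dy_desc, CUDNN_TENSOR_NCHW,
                                         CUDNN_DATA_FLOAT, 1, 16, 8, 8));
  CHECK_CUDNN(cudnnSetTensor4dDescriptor(dx_desc, CUDNN_TENSOR_NCHW,
                                         CUDNN_DATA_FLOAT, 1, 3, 15, 15));
  CHECK_CUDNN(cudnnSetFilter4dDescriptor(w_desc, CUDNN_DATA_FLOAT,
                                         CUDNN_TENSOR_NCHW, 16, 3, 5, 5));
  CHECK_CUDNN(cudnnSetConvolution2dDescriptor(
      conv_desc, /*pad_h=*/2, /*pad_w=*/2, /*stride_h=*/2, /*stride_w=*/2,
      /*dilation_h=*/1, /*dilation_w=*/1, CUDNN_CROSS_CORRELATION,
      CUDNN_DATA_FLOAT));

  // Choose an algorithm that fits under the workspace cap.
  const size_t workspace_size_limit = 4UL * 1024 * 1024;  // 4 MiB cap
  cudnnConvolutionBwdDataAlgo_t algo;
  CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
      handle, w_desc, dy_desc, conv_desc, dx_desc,
      CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
      workspace_size_limit, &algo));

  // Ask for the exact workspace that algorithm needs; the size is written
  // straight into workspace_size_in_bytes, as in the patched kernel.
  size_t workspace_size_in_bytes = 0;
  CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
      handle, w_desc, dy_desc, conv_desc, dx_desc, algo,
      &workspace_size_in_bytes));
  printf("algo=%d workspace=%zu bytes\n", (int)algo, workspace_size_in_bytes);

  cudnnDestroyConvolutionDescriptor(conv_desc);
  cudnnDestroyFilterDescriptor(w_desc);
  cudnnDestroyTensorDescriptor(dx_desc);
  cudnnDestroyTensorDescriptor(dy_desc);
  cudnnDestroy(handle);
  return 0;
}
```

In the kernel above the same two calls are routed through `platform::dynload` and wrapped in `PADDLE_ENFORCE`; since only one algorithm is queried, writing the size directly into `workspace_size_in_bytes` is equivalent to the old `std::max` accumulation over a single `tmp_size`.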