From 2d44a2ec5a55699252bb64aa4a57186705c73d5f Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Mon, 30 Oct 2017 19:37:45 -0700 Subject: [PATCH] deconv cudnn --- paddle/operators/conv2dtranspose_cudnn_op.cc | 50 ++++ paddle/operators/conv2dtranspose_cudnn_op.cu | 276 +++++++++++++++++++ 2 files changed, 326 insertions(+) create mode 100644 paddle/operators/conv2dtranspose_cudnn_op.cc create mode 100644 paddle/operators/conv2dtranspose_cudnn_op.cu diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cc b/paddle/operators/conv2dtranspose_cudnn_op.cc new file mode 100644 index 00000000000..72c470389cb --- /dev/null +++ b/paddle/operators/conv2dtranspose_cudnn_op.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv2dtranspose_op.h" + +namespace paddle { +namespace operators { + +class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker { + public: + CudnnConv2DTransposeOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : Conv2DTransposeOpMaker(proto, op_checker) { + AddAttr>("dilations", "dilations of convolution operator.") + .SetDefault(std::vector{1, 1}); + AddAttr("workspace_size_MB", + "workspace size for cudnn, in MB, " + "workspace is a section of GPU memory which will be " + "allocated/freed each time the operator runs, larger " + "workspace size can increase performance but also requires " + "better hardward. This size should be carefully setted.") + .SetDefault(4096); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv2dtranspose_cudnn, ops::Conv2DTransposeOp, + ops::CudnnConv2DTransposeOpMaker, conv2dtranspose_cudnn_grad, + ops::Conv2DTransposeOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv2dtranspose_cudnn, + ops::GemmConv2DTransposeKernel); +REGISTER_OP_CPU_KERNEL( + conv2dtranspose_cudnn_grad, + ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cu b/paddle/operators/conv2dtranspose_cudnn_op.cu new file mode 100644 index 00000000000..e9bad8c5174 --- /dev/null +++ b/paddle/operators/conv2dtranspose_cudnn_op.cu @@ -0,0 +1,276 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "glog/logging.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memory.h" +#include "paddle/operators/conv2d_op.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; +using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; +using DataLayout = platform::DataLayout; +using CUDADeviceContext = platform::CUDADeviceContext; + +static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; + +template +class CudnnConvTransposeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* output = ctx.Output("Output"); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + + const T* input_data = input->data(); + const T* filter_data = filter->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedFilterDescriptor filter_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout = DataLayout::kNCHW; + + // N, M, H, W + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + // N, C, O_h, O_w + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims())); + // M, C, K_h, K_w + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims())); + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(paddings, strides, dilations); + + int input_channels = input->dims()[1]; // M + int input_height = input->dims()[2]; // H + int input_width = input->dims()[3]; // W + int output_channels = output->dims()[1]; // C + int output_height = output->dims()[2]; // O_H + int output_width = output->dims()[3]; // O_W + + // ------------------- cudnn conv workspace --------------------- + void* cudnn_workspace = nullptr; + size_t workspace_size_in_bytes; // final workspace to allocate. + size_t tmp_size; + size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + // ------------------- cudnn conv algorithm --------------------- + cudnnConvolutionBwdAlgo_t algo; + auto handle = ctx.cuda_device_context().cudnn_handle(); + // Get the algorithm + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( + handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, + // dxDesc: Handle to the previously initialized output tensor + // descriptor. + cudnn_output_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + + // get workspace size able to allocate + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, + cudnn_output_desc, algo, &tmp_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + + // Allocate on GPU memory + platform::GPUPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + + // ------------------- cudnn conv transpose forward --------------------- + T alpha = 1.0f, beta = 0.0f; + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, filter_data, cudnn_input_desc, + input_data, cudnn_conv_desc, algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_output_desc, output_data)); + + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); + } +}; + +/* +template +class CudnnConvTransposeGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto input = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto output_grad = ctx.Input(framework::GradVarName("Output")); + auto input_grad = ctx.Output(framework::GradVarName("Input")); + auto filter_grad = ctx.Output(framework::GradVarName("Filter")); + + const T* input_data = input->data(); + const T* output_grad_data = output_grad->data(); + const T* filter_data = filter->data(); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_grad_desc; + ScopedTensorDescriptor input_grad_desc; + + ScopedFilterDescriptor filter_desc; + ScopedFilterDescriptor filter_grad_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout = DataLayout::kNCHW; + + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims()), groups); + cudnnTensorDescriptor_t cudnn_output_grad_desc = + output_grad_desc.descriptor( + layout, framework::vectorize2int(output_grad->dims()), groups); + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims()), groups); + cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr; + cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr; + + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(paddings, strides, dilations); + + int input_channels = input->dims()[1]; + int input_height = input->dims()[2]; + int input_width = input->dims()[3]; + int output_grad_channels = filter->dims()[0]; + int output_grad_height = output_grad->dims()[2]; + int output_grad_width = output_grad->dims()[3]; + + int group_offset_in = input_channels / groups * input_height * input_width; + int group_offset_out = + output_grad_channels / groups * output_grad_height * output_grad_width; + int group_offset_filter = filter->numel() / groups; + // ------------------- cudnn backward algorithm --------------------- + cudnnConvolutionBwdDataAlgo_t data_algo; + cudnnConvolutionBwdFilterAlgo_t filter_algo; + size_t workspace_size_in_bytes = 0, tmp_size = 0; + size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + + auto handle = ctx.cuda_device_context().cudnn_handle(); + if (input_grad) { + cudnn_input_grad_desc = input_grad_desc.descriptor( + layout, framework::vectorize2int(input_grad->dims()), groups); + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( + handle, cudnn_filter_desc, + // dyDesc: Handle to the previously initialized input differential + // tensor descriptor. + cudnn_output_grad_desc, cudnn_conv_desc, + // dxDesc: Handle to the previously initialized output tensor + // descriptor. + cudnn_input_grad_desc, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &data_algo)); + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, cudnn_filter_desc, cudnn_output_grad_desc, + cudnn_conv_desc, cudnn_input_grad_desc, data_algo, &tmp_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + } + + if (filter_grad) { + cudnn_filter_grad_desc = filter_grad_desc.descriptor( + layout, framework::vectorize2int(filter_grad->dims()), groups); + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( + handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, + cudnn_filter_desc, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &filter_algo)); + + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, + cudnn_filter_desc, filter_algo, &tmp_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + } + // ------------------- cudnn conv workspace --------------------- + // Already on GPU + void* cudnn_workspace = nullptr; + platform::GPUPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // ------------------- cudnn conv backward data --------------------- + // FIXME(typhoonzero): template type T may not be the same as cudnn call. + T alpha = 1.0f, beta = 0.0f; + if (input_grad) { + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(ctx.GetEigenDevice()) = + t.constant(static_cast(0)); + for (int i = 0; i < groups; i++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, + filter_data + i * group_offset_filter, cudnn_output_grad_desc, + output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_input_grad_desc, input_grad_data + i * group_offset_in)); + } + } + // ------------------- cudnn conv backward filter --------------------- + if (filter_grad) { + T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*filter_grad); + t.device(ctx.GetEigenDevice()) = + t.constant(static_cast(0)); + for (int i = 0; i < groups; i++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_output_grad_desc, output_grad_data + i * group_offset_out, + cudnn_conv_desc, filter_algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_filter_grad_desc, + filter_grad_data + i * group_offset_filter)); + } + } + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); + } +}; +*/ + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn, + ops::CudnnConvTransposeOpKernel); +// REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn_grad, +// ops::CudnnConvTransposeGradOpKernel); -- GitLab