diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cc b/paddle/operators/conv2d_transpose_cudnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8ce94e0f04f14e1eae7e7d01280601cc72dea8c4 --- /dev/null +++ b/paddle/operators/conv2d_transpose_cudnn_op.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv2d_transpose_op.h" + +namespace paddle { +namespace operators { + +class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker { + public: + CudnnConv2DTransposeOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : Conv2DTransposeOpMaker(proto, op_checker) { + AddAttr>("dilations", "dilations of convolution operator.") + .SetDefault(std::vector{1, 1}); + AddAttr("workspace_size_MB", + "workspace size for cudnn, in MB, " + "workspace is a section of GPU memory which will be " + "allocated/freed each time the operator runs, larger " + "workspace size can increase performance but also requires " + "better hardward. This size should be carefully setted.") + .SetDefault(4096); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv2d_transpose_cudnn, ops::Conv2DTransposeOp, + ops::CudnnConv2DTransposeOpMaker, conv2d_transpose_cudnn_grad, + ops::Conv2DTransposeOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv2d_transpose_cudnn, + ops::GemmConv2DTransposeKernel); +REGISTER_OP_CPU_KERNEL( + conv2d_transpose_cudnn_grad, + ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu b/paddle/operators/conv2d_transpose_cudnn_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..61fcfb3bd8fa57f2c45fbf3a980dbe41041cff18 --- /dev/null +++ b/paddle/operators/conv2d_transpose_cudnn_op.cu @@ -0,0 +1,240 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memory.h" +#include "paddle/operators/conv2d_transpose_op.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; +using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; +using DataLayout = platform::DataLayout; +using CUDADeviceContext = platform::CUDADeviceContext; + +static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024; + +template +class CudnnConvTransposeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* output = ctx.Output("Output"); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + // cudnn v5 does not support dilations + std::vector dilations = ctx.Attr>("dilations"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + + const T* input_data = input->data(); + const T* filter_data = filter->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedFilterDescriptor filter_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout = DataLayout::kNCHW; + + // N, M, H, W + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + // N, C, O_h, O_w + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims())); + // M, C, K_h, K_w + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims())); + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(paddings, strides, dilations); + + // ------------------- cudnn conv workspace --------------------- + void* cudnn_workspace = nullptr; + size_t workspace_size_in_bytes; // final workspace to allocate. + size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + // ------------------- cudnn conv algorithm --------------------- + cudnnConvolutionBwdDataAlgo_t algo; + auto handle = ctx.cuda_device_context().cudnn_handle(); + // Get the algorithm + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( + handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, + // dxDesc: Handle to the previously initialized output tensor + // descriptor. + cudnn_output_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + + // get workspace size able to allocate + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, + cudnn_output_desc, algo, &workspace_size_in_bytes)); + + // Allocate on GPU memory + platform::GPUPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + + // ------------------- cudnn conv transpose forward --------------------- + T alpha = 1.0f, beta = 0.0f; + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, filter_data, cudnn_input_desc, + input_data, cudnn_conv_desc, algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_output_desc, output_data)); + + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); + } +}; + +template +class CudnnConvTransposeGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto input = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto output_grad = ctx.Input(framework::GradVarName("Output")); + auto input_grad = ctx.Output(framework::GradVarName("Input")); + auto filter_grad = ctx.Output(framework::GradVarName("Filter")); + const T* input_data = input->data(); + const T* output_grad_data = output_grad->data(); + const T* filter_data = filter->data(); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + // cudnn v5 does not support dilations + std::vector dilations = ctx.Attr>("dilations"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedFilterDescriptor filter_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout = DataLayout::kNCHW; + + // Input: (N, M, H, W) + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + // Output: (N, C, O_H, O_W) + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output_grad->dims())); + // Filter (M, C, K_H, K_W) + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims())); + + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(paddings, strides, dilations); + + // ------------------- cudnn backward algorithm --------------------- + cudnnConvolutionFwdAlgo_t data_algo; + cudnnConvolutionBwdFilterAlgo_t filter_algo; + size_t bwd_filter_ws_size, fwd_ws_size; + size_t workspace_size_in_bytes = 0; + size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + + auto handle = ctx.cuda_device_context().cudnn_handle(); + if (input_grad) { + // choose backward algorithm for data + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &data_algo)); + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_input_desc, data_algo, &fwd_ws_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size); + } + + if (filter_grad) { + // choose backward algorithm for filter + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( + handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, + cudnn_filter_desc, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &filter_algo)); + + // get workspace for backwards filter algorithm + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, + cudnn_filter_desc, filter_algo, &bwd_filter_ws_size)); + workspace_size_in_bytes = + std::max(workspace_size_in_bytes, bwd_filter_ws_size); + } + + // ------------------- cudnn conv workspace --------------------- + // Already on GPU + void* cudnn_workspace = nullptr; + platform::GPUPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // ------------------- cudnn conv backward data --------------------- + // FIXME(typhoonzero): template type T may not be the same as cudnn call. + T alpha = 1.0f, beta = 0.0f; + if (input_grad) { + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(ctx.GetEigenDevice()) = + t.constant(static_cast(0)); + + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_output_desc, output_grad_data, + cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, + input_grad_data)); + } + + // ------------------- cudnn conv backward filter --------------------- + if (filter_grad) { + T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*filter_grad); + t.device(ctx.GetEigenDevice()) = + t.constant(static_cast(0)); + // Gradient with respect to the filter + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc, + input_data, cudnn_conv_desc, filter_algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data)); + } + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn, + ops::CudnnConvTransposeOpKernel); +REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad, + ops::CudnnConvTransposeGradOpKernel); diff --git a/paddle/operators/conv2dtranspose_op.cc b/paddle/operators/conv2d_transpose_op.cc similarity index 95% rename from paddle/operators/conv2dtranspose_op.cc rename to paddle/operators/conv2d_transpose_op.cc index c1b231906e2f172b6f9cee55f850d1a5ec6c3221..348527728bdd4ed60676d6e6e44c4e761b803096 100644 --- a/paddle/operators/conv2dtranspose_op.cc +++ b/paddle/operators/conv2d_transpose_op.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv2dtranspose_op.h" +#include "paddle/operators/conv2d_transpose_op.h" namespace paddle { namespace operators { @@ -95,13 +95,13 @@ void Conv2DTransposeOpGrad::InferShape( } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp, - ops::Conv2DTransposeOpMaker, conv2dtranspose_grad, +REGISTER_OP(conv2d_transpose, ops::Conv2DTransposeOp, + ops::Conv2DTransposeOpMaker, conv2d_transpose_grad, ops::Conv2DTransposeOpGrad); REGISTER_OP_CPU_KERNEL( - conv2dtranspose, + conv2d_transpose, ops::GemmConv2DTransposeKernel); REGISTER_OP_CPU_KERNEL( - conv2dtranspose_grad, + conv2d_transpose_grad, ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.cu b/paddle/operators/conv2d_transpose_op.cu similarity index 89% rename from paddle/operators/conv2dtranspose_op.cu rename to paddle/operators/conv2d_transpose_op.cu index 761bc1959e69be94f43571728e6b92a322558b99..931ac9eed294c4fe7c726d8cc2c4d9a39ec12828 100644 --- a/paddle/operators/conv2dtranspose_op.cu +++ b/paddle/operators/conv2d_transpose_op.cu @@ -12,13 +12,13 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv2dtranspose_op.h" +#include "paddle/operators/conv2d_transpose_op.h" namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - conv2dtranspose, + conv2d_transpose, ops::GemmConv2DTransposeKernel); REGISTER_OP_GPU_KERNEL( - conv2dtranspose_grad, + conv2d_transpose_grad, ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.h b/paddle/operators/conv2d_transpose_op.h similarity index 99% rename from paddle/operators/conv2dtranspose_op.h rename to paddle/operators/conv2d_transpose_op.h index 8c70b3dcec1e26ab3d8a42d88040764c643b5ae6..cab7788227690621a0e5b744197b86c515bbef72 100644 --- a/paddle/operators/conv2dtranspose_op.h +++ b/paddle/operators/conv2d_transpose_op.h @@ -62,7 +62,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); // TODO(Zhuoyuan): Paddings can be added in future. - // groups will alway be disabled in conv2dtranspose. + // groups will alway be disabled in conv2d_transpose. const int batch_size = input->dims()[0]; const int m = input->dims()[1]; diff --git a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py b/python/paddle/v2/framework/tests/test_conv2d_transpose_op.py similarity index 90% rename from python/paddle/v2/framework/tests/test_conv2dtranspose_op.py rename to python/paddle/v2/framework/tests/test_conv2d_transpose_op.py index 53604c58b70a534dff6b0a668d380fb8f10f53f6..999a0bdc629010d96a8db31b317ba7a65bf35773 100644 --- a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_transpose_op.py @@ -45,23 +45,36 @@ class TestConv2dTransposeOp(OpTest): filter_ = np.random.random(self.filter_size).astype("float32") output = conv2dtranspose_forward_naive( input_, filter_, conv2dtranspose_param).astype('float32') - # print 'deconv output py', output, output.shape self.inputs = {'Input': input_, 'Filter': filter_} self.attrs = { 'strides': self.stride, 'paddings': self.pad, - # 'dilations': self.dilations + 'dilations': self.dilations } self.outputs = {'Output': output} def test_check_output(self): - print 'check output here' + print 'check output here for', self.op_type self.check_output() - def test_check_grad(self): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.op_type = "conv2d_transpose" + + def test_check_grad_no_input(self): self.check_grad( - set(['Input', 'Filter']), 'Output', max_relative_error=0.05) + ['Filter'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Input'])) def test_check_grad_no_filter(self): self.check_grad( @@ -70,33 +83,15 @@ class TestConv2dTransposeOp(OpTest): max_relative_error=0.05, no_grad_set=set(['Filter'])) - def test_check_grad_no_input(self): + def test_check_grad(self): self.check_grad( - ['Filter'], - 'Output', - max_relative_error=0.05, - no_grad_set=set(['Input'])) + set(['Input', 'Filter']), 'Output', max_relative_error=0.05) - def init_test_case(self): - self.pad = [0, 0] - self.stride = [1, 1] - self.dilations = [1, 1] - self.input_size = [2, 3, 5, 5] # NCHW - f_c = self.input_size[1] - self.filter_size = [f_c, 6, 3, 3] +class TestCudnn(TestConv2dTransposeOp): def init_op_type(self): - self.op_type = "conv2dtranspose" - + self.op_type = "conv2d_transpose_cudnn" -""" -class TestCudnn(TestConv2dOp): - def init_group(self): - self.groups = 1 - - def init_op_type(self): - self.op_type = "conv_cudnn" -""" if __name__ == '__main__': unittest.main()