conv2d_transpose_cudnn_op.cu 10.6 KB
Newer Older
Z
zchen0211 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memory.h"
Z
zchen0211 已提交
18
#include "paddle/operators/conv2d_transpose_op.h"
Z
zchen0211 已提交
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
#include "paddle/platform/assert.h"
#include "paddle/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

// Short local names for the framework / platform types used by the kernels
// below.
using CUDADeviceContext = platform::CUDADeviceContext;
using DataLayout = platform::DataLayout;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using Tensor = framework::Tensor;

// Default upper bound (1 GiB = 2^30 bytes) on cuDNN workspace memory; the
// kernels in this file use it when the "workspace_size_MB" attribute is not
// positive.
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
    static_cast<size_t>(1) << 30;

// cuDNN-backed forward kernel for conv2d_transpose.
//
// The transposed convolution is computed as the data-gradient of an ordinary
// convolution. cudnnConvolutionBackwardData is called with this op's
// "Input" in the dy role and its "Output" in the dx role.
//
//   "Input"  : described as (N, M, H, W), NCHW layout
//   "Filter" : described as (M, C, K_h, K_w)
//   "Output" : described as (N, C, O_h, O_w). Its dims must already be set
//              when this kernel runs (presumably by the op's shape
//              inference, which is not shown in this file).
//
// Attributes: "strides", "paddings", "dilations" are forwarded to the cuDNN
// convolution descriptor. "workspace_size_MB" caps the cuDNN workspace; a
// non-positive value selects kCONV_CUDNN_WORKSPACE_LIMIT_BYTES.
template <typename T>
class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");
    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* output = ctx.Output<Tensor>("Output");

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int user_workspace_size = ctx.Attr<int>("workspace_size_MB");

    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());

    // ------------------- cudnn descriptors ---------------------
    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    ScopedFilterDescriptor filter_desc;
    ScopedConvolutionDescriptor conv_desc;
    DataLayout layout = DataLayout::kNCHW;

    // Input: (N, M, H, W)
    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize2int(input->dims()));
    // Output: (N, C, O_h, O_w)
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        layout, framework::vectorize2int(output->dims()));
    // Filter: (M, C, K_h, K_w)
    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
        layout, framework::vectorize2int(filter->dims()));
    cudnnConvolutionDescriptor_t cudnn_conv_desc =
        conv_desc.descriptor<T>(paddings, strides, dilations);

    // ------------------- cudnn conv workspace limit ---------------------
    size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
    if (user_workspace_size > 0) {
      // Widen before multiplying: in int arithmetic this overflows once the
      // attribute reaches 2048 MB.
      workspace_size_limit =
          static_cast<size_t>(user_workspace_size) * 1024 * 1024;
    }

    // ------------------- cudnn conv algorithm ---------------------
    cudnnConvolutionBwdDataAlgo_t algo;
    auto handle = ctx.cuda_device_context().cudnn_handle();
    // Pick a backward-data algorithm whose workspace fits within the limit.
    PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
        handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
        // dxDesc: descriptor of this op's output tensor.
        cudnn_output_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
        workspace_size_limit, &algo));

    // Query the exact workspace the chosen algorithm needs.
    size_t workspace_size_in_bytes = 0;
    PADDLE_ENFORCE(
        platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
            handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
            cudnn_output_desc, algo, &workspace_size_in_bytes));

    // Allocate the workspace on the current GPU. The guard releases it when
    // Compute returns, including when a PADDLE_ENFORCE below throws. The
    // original explicit Free was skipped on that path and leaked.
    platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
    struct WorkspaceGuard {
      platform::GPUPlace place;
      void* ptr;
      ~WorkspaceGuard() { paddle::memory::Free(place, ptr); }
    };
    WorkspaceGuard cudnn_workspace{
        gpu, paddle::memory::Alloc(gpu, workspace_size_in_bytes)};

    // ------------------- cudnn conv transpose forward ---------------------
    // alpha = 1, beta = 0: the result overwrites output_data rather than
    // being accumulated into it.
    T alpha = 1.0f, beta = 0.0f;
    PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
        handle, &alpha, cudnn_filter_desc, filter_data, cudnn_input_desc,
        input_data, cudnn_conv_desc, algo, cudnn_workspace.ptr,
        workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
    // The workspace is released by cudnn_workspace's destructor.
  }
};

// cuDNN-backed gradient kernel for conv2d_transpose.
//
// Because the forward pass is an ordinary convolution's backward-data step,
// the gradients map onto the other two cuDNN convolution entry points:
//   Input@GRAD  = cudnnConvolutionForward(x = Output@GRAD, w = Filter)
//   Filter@GRAD = cudnnConvolutionBackwardFilter(x = Output@GRAD,
//                                                dy = Input)
// Either gradient output may be absent (null); it is then neither planned
// nor computed.
//
// Attributes are read the same way as in the forward kernel. A single
// workspace is shared by both passes and sized for the larger requirement.
template <typename T>
class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");
    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
    const T* input_data = input->data<T>();
    const T* output_grad_data = output_grad->data<T>();
    const T* filter_data = filter->data<T>();

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int user_workspace_size = ctx.Attr<int>("workspace_size_MB");

    // ------------------- cudnn descriptors ---------------------
    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    ScopedFilterDescriptor filter_desc;
    ScopedConvolutionDescriptor conv_desc;
    DataLayout layout = DataLayout::kNCHW;

    // Input: (N, M, H, W)
    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize2int(input->dims()));
    // Output: (N, C, O_H, O_W), described from Output@GRAD's dims
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        layout, framework::vectorize2int(output_grad->dims()));
    // Filter: (M, C, K_H, K_W)
    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
        layout, framework::vectorize2int(filter->dims()));

    cudnnConvolutionDescriptor_t cudnn_conv_desc =
        conv_desc.descriptor<T>(paddings, strides, dilations);

    // ------------------- cudnn workspace limit ---------------------
    size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
    if (user_workspace_size > 0) {
      // Widen before multiplying: in int arithmetic this overflows once the
      // attribute reaches 2048 MB.
      workspace_size_limit =
          static_cast<size_t>(user_workspace_size) * 1024 * 1024;
    }

    // ------------------- cudnn backward algorithm ---------------------
    // These are only read inside the branch that assigned them.
    cudnnConvolutionFwdAlgo_t data_algo;
    cudnnConvolutionBwdFilterAlgo_t filter_algo;
    // Shared by both passes, so it must fit the larger of the two.
    size_t workspace_size_in_bytes = 0;

    auto handle = ctx.cuda_device_context().cudnn_handle();
    if (input_grad) {
      // Choose the algorithm for the data gradient (a forward convolution).
      PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
          handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
          cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
          workspace_size_limit, &data_algo));
      size_t fwd_ws_size = 0;
      PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
          handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
          cudnn_input_desc, data_algo, &fwd_ws_size));
      workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size);
    }

    if (filter_grad) {
      // Choose the algorithm for the filter gradient.
      PADDLE_ENFORCE(
          platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
              handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
              cudnn_filter_desc,
              CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
              workspace_size_limit, &filter_algo));

      // Query the workspace for the chosen backward-filter algorithm.
      size_t bwd_filter_ws_size = 0;
      PADDLE_ENFORCE(
          platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
              handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
              cudnn_filter_desc, filter_algo, &bwd_filter_ws_size));
      workspace_size_in_bytes =
          std::max(workspace_size_in_bytes, bwd_filter_ws_size);
    }

    // ------------------- cudnn conv workspace ---------------------
    // Allocate the workspace on the current GPU. The guard releases it when
    // Compute returns, including when mutable_data or a PADDLE_ENFORCE below
    // throws. The original explicit Free was skipped on those paths and
    // leaked.
    platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
    struct WorkspaceGuard {
      platform::GPUPlace place;
      void* ptr;
      ~WorkspaceGuard() { paddle::memory::Free(place, ptr); }
    };
    WorkspaceGuard cudnn_workspace{
        gpu, paddle::memory::Alloc(gpu, workspace_size_in_bytes)};

    // FIXME(typhoonzero): template type T may not be the same as cudnn call.
    // beta == 0 means cuDNN overwrites each destination without reading it.
    // So the gradients need no zero-fill beforehand.
    T alpha = 1.0f, beta = 0.0f;

    // ------------------- cudnn conv backward data ---------------------
    if (input_grad) {
      T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
      PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
          handle, &alpha, cudnn_output_desc, output_grad_data,
          cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
          cudnn_workspace.ptr, workspace_size_in_bytes, &beta,
          cudnn_input_desc, input_grad_data));
    }

    // ------------------- cudnn conv backward filter ---------------------
    if (filter_grad) {
      T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
      // Gradient with respect to the filter.
      PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
          handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
          input_data, cudnn_conv_desc, filter_algo, cudnn_workspace.ptr,
          workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data));
    }
    // The workspace is released by cudnn_workspace's destructor.
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

Z
zchen0211 已提交
235
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
Z
zchen0211 已提交
236
                       ops::CudnnConvTransposeOpKernel<float>);
Z
zchen0211 已提交
237
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
Z
zchen0211 已提交
238
                       ops::CudnnConvTransposeGradOpKernel<float>);