/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;
using DataLayout = platform::DataLayout;
using PoolingMode = platform::PoolingMode;
K
update  
Kexin Zhao 已提交
27 28
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
C
chengduoZH 已提交
29 30

template <typename T>
31
class PoolCUDNNOpKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
32 33 34
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
D
dzhwinter 已提交
35
                   "It must use CUDAPlace.");
C
chengduoZH 已提交
36 37 38 39 40 41 42

    const Tensor *input = ctx.Input<Tensor>("X");
    Tensor *output = ctx.Output<Tensor>("Out");

    const T *input_data = input->data<T>();
    T *output_data = output->mutable_data<T>(ctx.GetPlace());

C
chengduoZH 已提交
43
    std::string pooling_type = ctx.Attr<std::string>("pooling_type");
C
chengduoZH 已提交
44 45 46
    std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
47
    if (ctx.Attr<bool>("global_pooling")) {
C
chengduoZH 已提交
48
      for (size_t i = 0; i < ksize.size(); ++i) {
C
fix bug  
chengduoZH 已提交
49
        paddings[i] = 0;
C
chengduoZH 已提交
50 51 52 53 54 55 56 57
        ksize[i] = static_cast<int>(input->dims()[i + 2]);
      }
    }

    // ------------------- cudnn descriptors ---------------------
    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    ScopedPoolingDescriptor pool_desc;
C
chengduoZH 已提交
58 59 60 61 62 63 64
    DataLayout layout;

    if (strides.size() == 2U) {
      layout = DataLayout::kNCHW;
    } else {
      layout = DataLayout::kNCDHW;
    }
C
chengduoZH 已提交
65

C
chengduoZH 已提交
66 67 68 69
    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize2int(input->dims()));
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        layout, framework::vectorize2int(output->dims()));
C
chengduoZH 已提交
70 71 72 73 74 75 76 77 78 79 80 81 82

    PoolingMode pooling_mode;
    if (pooling_type == "max") {
      pooling_mode = PoolingMode::kMaximum;
    } else {
      pooling_mode = PoolingMode::kAverage;
    }

    cudnnPoolingDescriptor_t cudnn_pool_desc =
        pool_desc.descriptor(pooling_mode, ksize, paddings, strides);

    // ------------------- cudnn pool algorithm ---------------------
    auto handle = ctx.cuda_device_context().cudnn_handle();
K
update  
Kexin Zhao 已提交
83
    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
C
chengduoZH 已提交
84 85 86 87 88 89 90
    PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward(
        handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta,
        cudnn_output_desc, output_data));
  }
};

template <typename T>
91
class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
92 93 94
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
D
dzhwinter 已提交
95
                   "It must use CUDAPlace.");
C
chengduoZH 已提交
96 97 98 99 100 101 102

    const Tensor *input = ctx.Input<Tensor>("X");
    const Tensor *output = ctx.Input<Tensor>("Out");
    const Tensor *output_grad =
        ctx.Input<Tensor>(framework::GradVarName("Out"));
    Tensor *input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

C
chengduoZH 已提交
103
    std::string pooling_type = ctx.Attr<std::string>("pooling_type");
C
chengduoZH 已提交
104 105 106 107
    std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");

C
chengduoZH 已提交
108
    if (ctx.Attr<bool>("global_pooling")) {
C
fix bug  
chengduoZH 已提交
109 110
      for (size_t i = 0; i < ksize.size(); ++i) {
        paddings[i] = 0;
C
chengduoZH 已提交
111
        ksize[i] = static_cast<int>(input->dims()[i + 2]);
C
fix bug  
chengduoZH 已提交
112
      }
C
chengduoZH 已提交
113 114 115 116 117 118 119 120 121 122
    }

    const T *input_data = input->data<T>();
    const T *output_data = output->data<T>();
    const T *output_grad_data = output_grad->data<T>();

    // ------------------- cudnn descriptors ---------------------
    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    ScopedPoolingDescriptor pool_desc;
C
chengduoZH 已提交
123 124 125 126 127 128 129
    DataLayout layout;

    if (strides.size() == 2U) {
      layout = DataLayout::kNCHW;
    } else {
      layout = DataLayout::kNCDHW;
    }
C
chengduoZH 已提交
130

C
chengduoZH 已提交
131 132 133 134
    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize2int(input->dims()));
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        layout, framework::vectorize2int(output->dims()));
C
chengduoZH 已提交
135 136 137 138 139 140 141 142 143 144 145 146 147

    PoolingMode pooling_mode;
    if (pooling_type == "max") {
      pooling_mode = PoolingMode::kMaximum;
    } else {
      pooling_mode = PoolingMode::kAverage;
    }

    cudnnPoolingDescriptor_t cudnn_pool_desc =
        pool_desc.descriptor(pooling_mode, ksize, paddings, strides);

    // ------------------- cudnn pool algorithm ---------------------
    auto handle = ctx.cuda_device_context().cudnn_handle();
K
update  
Kexin Zhao 已提交
148
    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
C
chengduoZH 已提交
149 150
    if (input_grad) {
      T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
C
chengduoZH 已提交
151
      // Because beta is zero, it is unnecessary to reset input_grad.
C
chengduoZH 已提交
152 153 154

      PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward(
          handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
155 156
          cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data,
          &beta, cudnn_input_desc, input_grad_data));
C
chengduoZH 已提交
157 158 159 160 161 162 163 164
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
K
Kexin Zhao 已提交
165
namespace plat = paddle::platform;
C
chengduoZH 已提交
166

K
Kexin Zhao 已提交
167
REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
168
                   ops::PoolCUDNNOpKernel<float>,
K
Kexin Zhao 已提交
169 170 171
                   ops::PoolCUDNNOpKernel<double>,
                   ops::PoolCUDNNOpKernel<plat::float16>);
REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace,
172 173 174
                   ops::PoolCUDNNGradOpKernel<float>,
                   ops::PoolCUDNNGradOpKernel<double>);

K
Kexin Zhao 已提交
175
REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace,
176 177
                   ops::PoolCUDNNOpKernel<float>,
                   ops::PoolCUDNNOpKernel<double>);
K
Kexin Zhao 已提交
178
REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace,
179 180
                   ops::PoolCUDNNGradOpKernel<float>,
                   ops::PoolCUDNNGradOpKernel<double>);