roi_pool_op.cu 8.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
W
wanghaox 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15 16
#include "paddle/fluid/operators/roi_pool_op.h"
#include "paddle/fluid/platform/cuda_helper.h"
W
wanghaox 已提交
17 18 19 20

namespace paddle {
namespace operators {

W
wanghaox 已提交
21 22
using Tensor = framework::Tensor;

W
wanghaox 已提交
23 24 25
// Threads per block used for every kernel launch in this file.
static constexpr int kNumCUDAThreads = 512;
// Upper bound on the number of blocks returned by NumBlocks().
static constexpr int kNumMaxinumNumBlocks = 4096;
// Number of int64 values stored per ROI. The kernels read them as
// [batch index, start w, start h, end w, end h].
static constexpr int kROISize = 5;
W
wanghaox 已提交
26

W
wanghaox 已提交
27 28 29
// Grid size for a launch over N elements: ceil(N / kNumCUDAThreads), clamped
// so that it never exceeds kNumMaxinumNumBlocks.
static inline int NumBlocks(const int N) {
  const int blocks_for_n = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  if (blocks_for_n > kNumMaxinumNumBlocks) {
    return kNumMaxinumNumBlocks;
  }
  return blocks_for_n;
}
W
wanghaox 已提交
31

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
// ROI max-pooling forward kernel.
//
// Launch layout: 1-D grid of 1-D blocks (only the .x components are used).
// Each thread starts at its global thread id and advances by the total number
// of launched threads, so any grid size covers all `nthreads` outputs.
//
// Parameters:
//   nthreads       total number of output elements to produce.
//   input_data     input feature map, indexed as
//                  ((batch * channels + c) * height + h) * width + w.
//   input_rois     kROISize int64 values per ROI:
//                  [batch index, start w, start h, end w, end h].
//   spatial_scale  factor applied to the ROI coordinates before rounding.
//   output_data    pooled maxima, indexed as
//                  ((n * channels + c) * pooled_height + ph) * pooled_width
//                  + pw.
//   argmax_data    optional (may be null); for each output, the offset
//                  h * width + w of the selected input inside its channel
//                  plane, or -1 when the bin is empty.
template <typename T>
__global__ void GPUROIPoolForward(const int nthreads, const T* input_data,
                                  const int64_t* input_rois,
                                  const float spatial_scale, const int channels,
                                  const int height, const int width,
                                  const int pooled_height,
                                  const int pooled_width, T* output_data,
                                  int64_t* argmax_data) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  const int offset = blockDim.x * gridDim.x;
  for (size_t i = index; i < static_cast<size_t>(nthreads); i += offset) {
    // Use the element handled in THIS iteration, not the thread's first
    // element; otherwise every pass recomputes the same output and elements
    // beyond the grid size are never written. i < nthreads, so it fits in int.
    const int out_index = static_cast<int>(i);
    int pw = out_index % pooled_width;
    int ph = (out_index / pooled_width) % pooled_height;
    int c = (out_index / pooled_width / pooled_height) % channels;
    int n = out_index / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * kROISize;
    int roi_batch_ind = offset_input_rois[0];
    int roi_start_w = round(offset_input_rois[1] * spatial_scale);
    int roi_start_h = round(offset_input_rois[2] * spatial_scale);
    int roi_end_w = round(offset_input_rois[3] * spatial_scale);
    int roi_end_h = round(offset_input_rois[4] * spatial_scale);

    // Degenerate ROIs are forced to be at least 1x1.
    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // Bin extent relative to the ROI origin.
    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    // Shift into feature-map coordinates and clip to the map boundaries.
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    // An empty bin yields 0 with argmax -1.
    T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
    int maxidx = -1;
    const T* offset_input_data =
        input_data + (roi_batch_ind * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_data_index = h * width + w;
        if (offset_input_data[input_data_index] > maxval) {
          maxval = offset_input_data[input_data_index];
          maxidx = input_data_index;
        }
      }
    }
    output_data[out_index] = maxval;
    if (argmax_data) {
      argmax_data[out_index] = maxidx;
    }
  }
}
W
wanghaox 已提交
90 91

// ROI max-pooling backward kernel.
//
// Launch layout: 1-D grid of 1-D blocks (only the .x components are used).
// Each thread starts at its global thread id and advances by the total number
// of launched threads until all `nthreads` output-gradient elements are
// handled.
//
// For every pooled output element (n, c, ph, pw) the incoming gradient is
// atomically added to the input-gradient location recorded in `argmax_data`.
// Entries whose argmax is -1 contribute nothing. `input_grad` is only
// accumulated into, never cleared, by this kernel.
//
// `num_rois` and `spatial_scale` are accepted but not read by this kernel.
template <typename T>
__global__ void GPUROIPoolBackward(
    const int nthreads, const int64_t* input_rois, const T* output_grad,
    const int64_t* argmax_data, const int num_rois, const float spatial_scale,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, T* input_grad) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  const int offset = blockDim.x * gridDim.x;
  // size_t counter: `i += offset` cannot overflow a signed int near INT_MAX.
  for (size_t i = index; i < static_cast<size_t>(nthreads); i += offset) {
    // Use the element handled in THIS iteration, not the thread's first
    // element; otherwise the same gradient is added on every pass and the
    // remaining elements are dropped. i < nthreads, so it fits in int.
    const int out_index = static_cast<int>(i);
    int pw = out_index % pooled_width;
    int ph = (out_index / pooled_width) % pooled_height;
    int c = (out_index / pooled_width / pooled_height) % channels;
    int n = out_index / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * kROISize;
    int roi_batch_ind = offset_input_rois[0];
    int input_offset = (roi_batch_ind * channels + c) * height * width;
    int output_offset = (n * channels + c) * pooled_height * pooled_width;
    const T* offset_output_grad = output_grad + output_offset;
    T* offset_input_grad = input_grad + input_offset;
    const int64_t* offset_argmax_data = argmax_data + output_offset;

    int argmax = offset_argmax_data[ph * pooled_width + pw];
    if (argmax != -1) {
      // Several outputs may map to one input location, hence the atomic add.
      platform::CudaAtomicAdd(
          offset_input_grad + argmax,
          static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
    }
  }
}
W
wanghaox 已提交
121 122

// CUDA implementation of the roi_pool forward operator.
//
// Inputs:  "X" (feature map; dims [1], [2], [3] are read as channels, height
//          and width) and "ROIs" (int64, one row per ROI).
// Outputs: "Out" (pooled maxima) and "Argmax" (int64 index of each maximum).
// Attrs:   pooled_height, pooled_width, spatial_scale.
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<Tensor>("ROIs");
    auto* out = ctx.Output<Tensor>("Out");
    auto* argmax = ctx.Output<Tensor>("Argmax");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];

    // Nothing to pool when there are no ROIs; no kernel is launched.
    size_t rois_num = rois->dims()[0];
    if (rois_num == 0) return;

    // One logical work item per output element; NumBlocks() caps the grid and
    // the kernel loops over the remainder.
    int output_size = out->numel();
    int blocks = NumBlocks(output_size);
    int threads = kNumCUDAThreads;

    GPUROIPoolForward<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
        channels, height, width, pooled_height, pooled_width,
        out->mutable_data<T>(ctx.GetPlace()),
        argmax->mutable_data<int64_t>(ctx.GetPlace()));
  }
};

// CUDA implementation of the roi_pool gradient operator.
//
// Reads "X", "ROIs", "Argmax" and the gradient of "Out". When the gradient of
// "X" is requested, it is allocated, filled with zeros, and then accumulated
// into by GPUROIPoolBackward.
template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* feature_map = ctx.Input<Tensor>("X");
    auto* roi_tensor = ctx.Input<Tensor>("ROIs");
    auto* argmax_tensor = ctx.Input<Tensor>("Argmax");

    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_h = ctx.Attr<int>("pooled_height");
    auto pooled_w = ctx.Attr<int>("pooled_width");
    auto scale = ctx.Attr<float>("spatial_scale");

    size_t roi_count = roi_tensor->dims()[0];
    int num_channels = feature_map->dims()[1];
    int map_height = feature_map->dims()[2];
    int map_width = feature_map->dims()[3];

    // No input gradient requested: nothing to compute.
    if (!dx) {
      return;
    }

    // The backward kernel only accumulates, so start from an all-zero buffer.
    dx->mutable_data<T>(ctx.GetPlace());
    math::SetConstant<Place, T> zero_filler;
    zero_filler(ctx.cuda_device_context(), dx, static_cast<T>(0));

    int grad_count = dout->numel();
    int grid_size = NumBlocks(grad_count);
    int block_size = kNumCUDAThreads;

    // Skip the launch entirely when there is no output gradient.
    if (grad_count <= 0) {
      return;
    }

    GPUROIPoolBackward<
        T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>(
        grad_count, roi_tensor->data<int64_t>(), dout->data<T>(),
        argmax_tensor->data<int64_t>(), roi_count, scale, num_channels,
        map_height, map_width, pooled_h, pooled_w,
        dx->mutable_data<T>(ctx.GetPlace()));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Forward operator: float and double CUDA kernels.
REGISTER_OP_CUDA_KERNEL(
    roi_pool,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, double>);
// Gradient operator: both precisions must use the gradient kernel class.
// Registering the forward kernel for double would run forward pooling when
// the double-precision gradient is requested.
REGISTER_OP_CUDA_KERNEL(
    roi_pool_grad,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, double>);