/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <limits>

#include "paddle/fluid/operators/roi_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

// Launch configuration shared by the ROI-pool kernels in this file.
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;

// Returns how many thread blocks of kNumCUDAThreads threads to launch for N
// work items: ceil(N / kNumCUDAThreads), capped at kNumMaxinumNumBlocks.
// Because of the cap, a kernel launched with this count must loop over its
// items with a stride of blockDim.x * gridDim.x to cover all N of them.
// NOTE(review): returns 0 for N == 0, which is not a valid grid size; callers
// must skip the launch in that case.
static inline int NumBlocks(const int N) {
  const int wanted = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  return std::min(wanted, kNumMaxinumNumBlocks);
}

// Max-pools every region of interest (ROI) into a fixed
// pooled_height x pooled_width grid of bins.
//
// Work decomposition: one work item per output element. Item i in
// [0, nthreads) is decoded as (n, c, ph, pw), with pw varying fastest, where n
// is the ROI index and c the channel. Each thread handles items
// i, i + blockDim.x * gridDim.x, ..., so any 1-D launch covers all items.
//
// input_data:        read as contiguous [batch, channels, height, width].
// input_rois:        kROISize int64 values per ROI. Elements 0..3 are read as
//                    (start_w, start_h, end_w, end_h), with the end inclusive.
//                    Each is multiplied by spatial_scale and rounded.
// roi_batch_id_data: device array; entry n is the batch image that ROI n
//                    samples from.
// output_data[i]:    max value inside the bin, or 0 when the bin is empty
//                    after clipping to the image.
// argmax_data[i]:    flat offset h * width + w of that max inside its
//                    (image, channel) plane, or -1 when no element was chosen.
//                    May be null; it is then not written.
template <typename T>
__global__ void GPUROIPoolForward(
    const int nthreads, const T* input_data, const int64_t* input_rois,
    const float spatial_scale, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    int* roi_batch_id_data, T* output_data, int64_t* argmax_data) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    // Decode the current loop item (i, not the thread's first index `index`);
    // otherwise every iteration would recompute and overwrite the same output.
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * kROISize;
    int roi_batch_ind = roi_batch_id_data[n];
    int roi_start_w = round(offset_input_rois[0] * spatial_scale);
    int roi_start_h = round(offset_input_rois[1] * spatial_scale);
    int roi_end_w = round(offset_input_rois[2] * spatial_scale);
    int roi_end_h = round(offset_input_rois[3] * spatial_scale);

    // End coordinates are inclusive; malformed ROIs are forced to be 1x1.
    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // floor for the start and ceil for the end, so adjacent bins may share a
    // row or column.
    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    // Shift the bin into image coordinates and clip it to the image.
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
    int maxidx = -1;
    const T* offset_input_data =
        input_data + (roi_batch_ind * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_data_index = h * width + w;
        if (offset_input_data[input_data_index] > maxval) {
          maxval = offset_input_data[input_data_index];
          maxidx = input_data_index;
        }
      }
    }
    output_data[i] = maxval;
    if (argmax_data) {
      argmax_data[i] = maxidx;
    }
  }
}

// Gradient of ROI max pooling: each output gradient is routed to the single
// input location recorded in argmax_data.
//
// Work decomposition: item i in [0, nthreads) is one output element, decoded
// as (n, c, ph, pw) with pw varying fastest. output_grad and argmax_data are
// indexed flat by (n * channels + c) * pooled_height * pooled_width +
// ph * pooled_width + pw, which equals i, so both are read at [i] directly.
// Each thread handles items i, i + blockDim.x * gridDim.x, ...
//
// argmax_data[i] is an offset inside the (roi_batch_id_data[n], c) plane of
// input_grad, whose planes are laid out as [batch, channels, height, width].
// A value of -1 means "no input contributed" and is skipped.
// Several outputs can share the same argmax location, so the gradient is
// accumulated atomically. input_grad is added to, never cleared here.
//
// input_rois, num_rois and spatial_scale are not read by this kernel.
template <typename T>
__global__ void GPUROIPoolBackward(
    const int nthreads, const int64_t* input_rois, const T* output_grad,
    const int64_t* argmax_data, const int num_rois, const float spatial_scale,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, int* roi_batch_id_data,
    T* input_grad) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    // Decode the current loop item (i, not the thread's first index `index`).
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    // Keep the full 64-bit offset stored by the forward pass.
    int64_t argmax = argmax_data[i];
    if (argmax == -1) {
      continue;
    }

    int roi_batch_ind = roi_batch_id_data[n];
    T* offset_input_grad =
        input_grad + (roi_batch_ind * channels + c) * height * width;
    platform::CudaAtomicAdd(offset_input_grad + argmax,
                            static_cast<T>(output_grad[i]));
  }
}

// CUDA kernel for the roi_pool operator's forward pass.
//
// Inputs:  "X"    - feature map; its dims are read as [batch, C, H, W].
//          "ROIs" - LoDTensor with one row per ROI. Its last LoD level assigns
//                   ROIs [lod[k], lod[k + 1]) to image k of the batch.
// Outputs: "Out", "Argmax" - one value per pooled cell, written by
//          GPUROIPoolForward on the current CUDA stream.
// Attrs:   pooled_height, pooled_width, spatial_scale.
//
// Raises (via PADDLE_ENFORCE_EQ) when the number of images described by the
// LoD differs from X's batch size, or when the LoD's total ROI count differs
// from the number of ROI rows.
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* out = ctx.Output<Tensor>("Out");
    auto* argmax = ctx.Output<Tensor>("Argmax");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    int batch_size = in_dims[0];
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];

    int rois_num = rois->dims()[0];
    // Nothing to pool. Out and Argmax are left unallocated in this case.
    if (rois_num == 0) return;

    // Build, on the host, the image index of every ROI from the last LoD
    // level, then copy that table to the kernel's device.
    framework::Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({rois_num});
    int* roi_batch_id_data =
        roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
    auto rois_lod = rois->lod().back();
    int rois_batch_size = rois_lod.size() - 1;
    PADDLE_ENFORCE_EQ(
        rois_batch_size, batch_size,
        "The rois_batch_size and imgs batch_size must be the same.");
    int rois_num_with_lod = rois_lod[rois_batch_size];
    PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
                      "The rois_num from input and lod must be the same.");
    for (int n = 0; n < rois_batch_size; ++n) {
      for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
        roi_batch_id_data[i] = n;
      }
    }

    framework::Tensor roi_batch_id_list_gpu;
    framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
                          ctx.device_context(), &roi_batch_id_list_gpu);

    // One work item per output element.
    int output_size = out->numel();
    int blocks = NumBlocks(output_size);
    int threads = kNumCUDAThreads;

    GPUROIPoolForward<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
        channels, height, width, pooled_height, pooled_width,
        roi_batch_id_list_gpu.data<int>(), out->mutable_data<T>(ctx.GetPlace()),
        argmax->mutable_data<int64_t>(ctx.GetPlace()));
  }
};

// CUDA kernel for the roi_pool operator's backward pass.
//
// Reads "X" (for its dims, read as [batch, C, H, W]), "ROIs" (LoDTensor whose
// last LoD level assigns ROIs to images), "Argmax" and the gradient of "Out".
// When the gradient of "X" is requested, it is zero-filled and then
// GPUROIPoolBackward adds each output gradient at the recorded argmax
// location. All device work is enqueued on the current CUDA stream.
template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* argmax = ctx.Input<Tensor>("Argmax");
    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    int rois_num = rois->dims()[0];
    int channels = in->dims()[1];
    int height = in->dims()[2];
    int width = in->dims()[3];

    // The gradient of X is not requested; nothing to compute.
    if (x_grad == nullptr) return;

    // Host-side table of the image index of every ROI, built from the last
    // LoD level (ROIs [lod[k], lod[k + 1]) belong to image k), then copied
    // to the kernel's device.
    framework::Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({rois_num});
    int* roi_batch_id_data =
        roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
    auto rois_lod = rois->lod().back();
    int rois_batch_size = rois_lod.size() - 1;
    for (int n = 0; n < rois_batch_size; ++n) {
      for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
        roi_batch_id_data[i] = n;
      }
    }
    framework::Tensor roi_batch_id_list_gpu;
    framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
                          ctx.device_context(), &roi_batch_id_list_gpu);

    // The backward kernel accumulates with atomic adds, so the gradient
    // buffer must start at zero. Allocate it once and reuse the pointer.
    T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
    math::SetConstant<Place, T> set_zero;
    set_zero(ctx.cuda_device_context(), x_grad, static_cast<T>(0));

    // A zero-sized grid is not a valid launch; X's gradient stays all zero.
    int output_grad_size = out_grad->numel();
    if (output_grad_size <= 0) return;

    int blocks = NumBlocks(output_grad_size);
    int threads = kNumCUDAThreads;
    GPUROIPoolBackward<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
        argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
        width, pooled_height, pooled_width, roi_batch_id_list_gpu.data<int>(),
        x_grad_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward kernels for float and double.
REGISTER_OP_CUDA_KERNEL(
    roi_pool,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, double>);
// Gradient kernels for float and double. Both must be the *Grad* kernel; the
// double entry previously pointed at the forward kernel by mistake.
REGISTER_OP_CUDA_KERNEL(
    roi_pool_grad,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, double>);