roi_pool_op.cu 9.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
W
wanghaox 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/roi_pool_op.h"
D
dzhwinter 已提交
16
#include "paddle/fluid/platform/cuda_primitives.h"
W
wanghaox 已提交
17 18 19 20

namespace paddle {
namespace operators {

W
wanghaox 已提交
21
// Shorthand for framework::Tensor, used by the operator kernels in this file.
using Tensor = framework::Tensor;
22
// Shorthand for framework::LoDTensor; the "ROIs" input is read as this type.
using LoDTensor = framework::LoDTensor;
W
wanghaox 已提交
23

W
wanghaox 已提交
24 25
// Number of threads per block used for every kernel launch in this file.
static constexpr int kNumCUDAThreads = 512;
// Upper bound on the number of blocks per launch. The kernels in this file
// loop with a stride of blockDim.x * gridDim.x, so a capped grid still visits
// every work item.
static constexpr int kNumMaxinumNumBlocks = 4096;

// Returns the number of blocks needed to cover N work items with
// kNumCUDAThreads threads per block, capped at kNumMaxinumNumBlocks.
//
// The ceil-division is carried out in 64-bit arithmetic: evaluating
// N + kNumCUDAThreads - 1 in `int` overflows (undefined behaviour) when N is
// within kNumCUDAThreads of INT_MAX.
static inline int NumBlocks(const int N) {
  const int64_t needed =
      (static_cast<int64_t>(N) + kNumCUDAThreads - 1) / kNumCUDAThreads;
  return static_cast<int>(
      std::min<int64_t>(needed, kNumMaxinumNumBlocks));
}
W
wanghaox 已提交
31

32
// Max-pools every region of interest (ROI) into a fixed
// pooled_height x pooled_width grid.
//
// Each loop iteration produces one output element i, decoded as
// (n, c, ph, pw) with pw varying fastest; n indexes the ROI and c the channel.
// Threads advance by blockDim.x * gridDim.x, so any 1-D launch configuration
// covers all nthreads elements.
//
// Args:
//   nthreads:          number of output elements to produce.
//   input_data:        input feature map, addressed as
//                      ((batch * channels + c) * height + h) * width + w.
//   input_rois:        kROISize int64 values per ROI; the first four are read
//                      as (start_w, start_h, end_w, end_h) before scaling.
//   spatial_scale:     factor applied to the ROI coordinates before rounding.
//   channels, height, width:
//                      dimensions of one input image.
//   pooled_height, pooled_width:
//                      dimensions of the pooled output per ROI and channel.
//   roi_batch_id_data: for each ROI, the index of the image it refers to.
//   output_data:       receives the maximum of each bin (0 for an empty bin).
//   argmax_data:       if non-null, receives h * width + w of the selected
//                      maximum inside the channel plane, or -1 when no
//                      element was selected.
template <typename T>
__global__ void GPUROIPoolForward(
    const int nthreads, const T* input_data, const int64_t* input_rois,
    const float spatial_scale, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    int* roi_batch_id_data, T* output_data, int64_t* argmax_data) {
  // 64-bit index arithmetic keeps the strided loop free of signed overflow
  // and avoids mixing signed and unsigned operands in the loop condition.
  const int64_t index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t offset = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = index; i < nthreads; i += offset) {
    // Decode the flat output index into (n, c, ph, pw).
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * kROISize;
    int roi_batch_ind = roi_batch_id_data[n];
    int roi_start_w = round(offset_input_rois[0] * spatial_scale);
    int roi_start_h = round(offset_input_rois[1] * spatial_scale);
    int roi_end_w = round(offset_input_rois[2] * spatial_scale);
    int roi_end_h = round(offset_input_rois[3] * spatial_scale);

    // Force malformed ROIs to be at least 1x1.
    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);

    // Bin boundaries relative to the ROI origin: floor for the start and
    // ceil for the end, computed in double precision.
    int hstart = static_cast<int>(floor(static_cast<double>(ph) *
                                        static_cast<double>(roi_height) /
                                        static_cast<double>(pooled_height)));
    int wstart = static_cast<int>(floor(static_cast<double>(pw) *
                                        static_cast<double>(roi_width) /
                                        static_cast<double>(pooled_width)));
    int hend = static_cast<int>(ceil(static_cast<double>(ph + 1) *
                                     static_cast<double>(roi_height) /
                                     static_cast<double>(pooled_height)));
    int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) *
                                     static_cast<double>(roi_width) /
                                     static_cast<double>(pooled_width)));

    // Shift into image coordinates and clamp to the image extent.
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
    int maxidx = -1;
    // The plane offset is computed in 64-bit: the product
    // batch * channels * height * width exceeds INT_MAX for large inputs.
    const int64_t plane_offset =
        (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
    const T* offset_input_data = input_data + plane_offset;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_data_index = h * width + w;
        if (offset_input_data[input_data_index] > maxval) {
          maxval = offset_input_data[input_data_index];
          maxidx = input_data_index;
        }
      }
    }
    output_data[i] = maxval;
    if (argmax_data) {
      argmax_data[i] = maxidx;
    }
  }
}
W
wanghaox 已提交
93 94

// Scatters the gradient of each pooled output element back to the input
// position that was selected as its maximum.
//
// Each loop iteration handles one element i of output_grad, decoded as
// (n, c, ph, pw) with pw varying fastest. Threads advance by
// blockDim.x * gridDim.x, so any 1-D launch configuration covers all
// nthreads elements.
//
// Args:
//   nthreads:          number of elements in output_grad.
//   input_rois, num_rois, spatial_scale:
//                      not read by this kernel body; kept in the signature
//                      for its callers.
//   output_grad:       gradient with respect to the pooled output.
//   argmax_data:       per output element, the offset of the selected maximum
//                      inside its (image, channel) plane, or -1 for none.
//   channels, height, width:
//                      dimensions of one input image.
//   pooled_height, pooled_width:
//                      dimensions of the pooled output per ROI and channel.
//   roi_batch_id_data: for each ROI, the index of the image it refers to.
//   input_grad:        accumulated into with atomic adds; the kernel does not
//                      clear it first.
template <typename T>
__global__ void GPUROIPoolBackward(
    const int nthreads, const int64_t* input_rois, const T* output_grad,
    const int64_t* argmax_data, const int num_rois, const float spatial_scale,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, int* roi_batch_id_data,
    T* input_grad) {
  // 64-bit index arithmetic: `i += offset` in `int` can overflow when
  // nthreads is close to INT_MAX.
  const int64_t index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t offset = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = index; i < nthreads; i += offset) {
    // Decode the flat output index into (n, c, ph, pw).
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    int roi_batch_ind = roi_batch_id_data[n];
    // Plane offsets are computed in 64-bit: the input tensor may hold more
    // than INT_MAX elements.
    const int64_t input_offset =
        (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
    const int64_t output_offset =
        (static_cast<int64_t>(n) * channels + c) * pooled_height *
        pooled_width;
    const T* offset_output_grad = output_grad + output_offset;
    T* offset_input_grad = input_grad + input_offset;
    const int64_t* offset_argmax_data = argmax_data + output_offset;

    // Keep the stored 64-bit value instead of narrowing it to int.
    const int64_t argmax = offset_argmax_data[ph * pooled_width + pw];
    if (argmax != -1) {
      // Several output elements can select the same input position, so the
      // accumulation is atomic.
      platform::CudaAtomicAdd(
          offset_input_grad + argmax,
          static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
    }
  }
}
W
wanghaox 已提交
124 125

// CUDA forward kernel of the roi_pool operator.
//
// Reads input "X" (dims read as batch, channels, height, width), the LoD
// tensor "ROIs" and the attributes pooled_height, pooled_width and
// spatial_scale; writes "Out" and "Argmax" by launching GPUROIPoolForward on
// the context's CUDA stream.
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* out = ctx.Output<Tensor>("Out");
    auto* argmax = ctx.Output<Tensor>("Argmax");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    int batch_size = in_dims[0];
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];

    int rois_num = rois->dims()[0];

    // No ROIs: return before any output buffer is requested.
    if (rois_num == 0) return;

    int output_size = out->numel();
    int blocks = NumBlocks(output_size);
    int threads = kNumCUDAThreads;

    // Calling back() on an empty LoD is undefined behaviour; fail with a
    // clear message instead.
    PADDLE_ENFORCE_EQ(rois->lod().empty(), false,
                      "Input(ROIs) of ROIPoolOp must carry LoD information.");

    // Build, on the host, a table mapping every ROI to the index of the image
    // it belongs to, using the last LoD level of ROIs.
    framework::Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({rois_num});
    int* roi_batch_id_data =
        roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
    auto rois_lod = rois->lod().back();
    int rois_batch_size = rois_lod.size() - 1;
    PADDLE_ENFORCE_EQ(
        rois_batch_size, batch_size,
        "The rois_batch_size and imgs batch_size must be the same.");
    int rois_num_with_lod = rois_lod[rois_batch_size];
    PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
                      "The rois_num from input and lod must be the same.");
    for (int n = 0; n < rois_batch_size; ++n) {
      for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
        roi_batch_id_data[i] = n;
      }
    }

    // Copy the table to the place the kernel runs on.
    framework::Tensor roi_batch_id_list_gpu;
    framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
                          ctx.device_context(), &roi_batch_id_list_gpu);

    GPUROIPoolForward<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
        channels, height, width, pooled_height, pooled_width,
        roi_batch_id_list_gpu.data<int>(), out->mutable_data<T>(ctx.GetPlace()),
        argmax->mutable_data<int64_t>(ctx.GetPlace()));
  }
};

// CUDA backward kernel of the roi_pool operator.
//
// Reads "X" (for its dimensions), the LoD tensor "ROIs", "Argmax" and the
// gradient of "Out"; when the gradient of "X" is requested it is zero-filled
// and then accumulated into by GPUROIPoolBackward on the context's CUDA
// stream.
template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* argmax = ctx.Input<Tensor>("Argmax");

    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    int rois_num = rois->dims()[0];
    int channels = in->dims()[1];
    int height = in->dims()[2];
    int width = in->dims()[3];

    if (x_grad) {
      // Calling back() on an empty LoD is undefined behaviour; fail with a
      // clear message instead.
      PADDLE_ENFORCE_EQ(rois->lod().empty(), false,
                        "Input(ROIs) of ROIPoolOp must carry LoD information.");

      // Build, on the host, a table mapping every ROI to the index of the
      // image it belongs to, using the last LoD level of ROIs.
      framework::Tensor roi_batch_id_list;
      roi_batch_id_list.Resize({rois_num});
      int* roi_batch_id_data =
          roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
      auto rois_lod = rois->lod().back();
      int rois_batch_size = rois_lod.size() - 1;
      for (int n = 0; n < rois_batch_size; ++n) {
        for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
          roi_batch_id_data[i] = n;
        }
      }
      // Copy the table to the place the kernel runs on.
      framework::Tensor roi_batch_id_list_gpu;
      framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
                            ctx.device_context(), &roi_batch_id_list_gpu);

      // The backward kernel only adds into the gradient buffer, so it has to
      // start from zero.
      T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
      math::SetConstant<Place, T> set_zero;
      set_zero(ctx.cuda_device_context(), x_grad, static_cast<T>(0));

      int output_grad_size = out_grad->numel();
      int blocks = NumBlocks(output_grad_size);
      int threads = kNumCUDAThreads;

      // Skip the launch for an empty gradient: NumBlocks would yield a grid
      // of zero blocks.
      if (output_grad_size > 0) {
        GPUROIPoolBackward<
            T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
            output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
            argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
            width, pooled_height, pooled_width,
            roi_batch_id_list_gpu.data<int>(), x_grad_data);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Forward kernels of roi_pool for float and double.
REGISTER_OP_CUDA_KERNEL(
    roi_pool,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, double>);
// Gradient kernels of roi_pool for float and double. Both element types must
// use GPUROIPoolGradOpKernel: registering the forward class for double would
// run the forward pooling code whenever a double gradient is requested.
REGISTER_OP_CUDA_KERNEL(
    roi_pool_grad,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, double>);