/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <limits>

#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/roi_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Short aliases for the framework tensor types used by the kernels below.
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

// Threads per block for every ROI pooling kernel launch in this file.
static constexpr int kNumCUDAThreads = 512;
// Upper bound on the number of blocks per launch. The ROI pooling kernels in
// this file iterate with a stride of blockDim.x * gridDim.x, so a capped grid
// still covers every element.
static constexpr int kNumMaxinumNumBlocks = 4096;

// Returns the grid size for N work items: ceil(N / kNumCUDAThreads), capped
// at kNumMaxinumNumBlocks. The result is never less than 1, because a launch
// with zero blocks is an invalid configuration. With the stride loop, an
// extra block simply does no work when N is 0.
static inline int NumBlocks(const int N) {
  return std::max(1, std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                              kNumMaxinumNumBlocks));
}

// ROI max-pooling forward kernel.
//
// Each work item i in [0, nthreads) is one output element, decomposed as
// (n, c, ph, pw) with pw varying fastest. output_data and argmax_data are
// therefore laid out as [num_rois, channels, pooled_height, pooled_width].
// input_data is indexed as [batch, channels, height, width].
//
// input_rois: kROISize values per ROI. Elements 0..3 are read as
//   (x_start, y_start, x_end, y_end), multiplied by spatial_scale and rounded
//   to feature-map cells.
// roi_batch_id_data: device array giving, for each ROI, the batch index of
//   the feature map it pools from.
// argmax_data: may be nullptr. When present, it receives the flat
//   h * width + w offset of the selected maximum within its (batch, channel)
//   plane, or -1 for an empty bin (whose output is 0).
//
// Launch: any 1-D grid/block shape. Threads advance by blockDim.x * gridDim.x
// so every element is visited. No shared memory is used.
template <typename T>
__global__ void GPUROIPoolForward(
    const int nthreads, const T* input_data, const T* input_rois,
    const float spatial_scale, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    int* roi_batch_id_data, T* output_data, int64_t* argmax_data) {
  const int stride = blockDim.x * gridDim.x;
  // Signed index so the loop condition compares int with int (nthreads).
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += stride) {
    const int pw = i % pooled_width;
    const int ph = (i / pooled_width) % pooled_height;
    const int c = (i / pooled_width / pooled_height) % channels;
    const int n = i / pooled_width / pooled_height / channels;

    // Project ROI n onto the feature map.
    const T* offset_input_rois = input_rois + n * kROISize;
    const int roi_batch_ind = roi_batch_id_data[n];
    const int roi_start_w = round(offset_input_rois[0] * spatial_scale);
    const int roi_start_h = round(offset_input_rois[1] * spatial_scale);
    const int roi_end_w = round(offset_input_rois[2] * spatial_scale);
    const int roi_end_h = round(offset_input_rois[3] * spatial_scale);

    // Degenerate ROIs are forced to cover at least one cell.
    const int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    const int roi_height = max(roi_end_h - roi_start_h + 1, 1);

    // Bin (ph, pw) spans [floor(ph * roi_height / pooled_height),
    // ceil((ph + 1) * roi_height / pooled_height)) rows relative to the ROI
    // start, and likewise for columns. The division is done in double.
    int hstart = static_cast<int>(floor(static_cast<double>(ph) *
                                        static_cast<double>(roi_height) /
                                        static_cast<double>(pooled_height)));
    int wstart = static_cast<int>(floor(static_cast<double>(pw) *
                                        static_cast<double>(roi_width) /
                                        static_cast<double>(pooled_width)));
    int hend = static_cast<int>(ceil(static_cast<double>(ph + 1) *
                                     static_cast<double>(roi_height) /
                                     static_cast<double>(pooled_height)));
    int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) *
                                     static_cast<double>(roi_width) /
                                     static_cast<double>(pooled_width)));

    // Shift into feature-map coordinates and clip to its bounds.
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    const bool is_empty = (hend <= hstart) || (wend <= wstart);

    // Empty bins produce 0 with argmax -1.
    T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
    int maxidx = -1;
    const T* offset_input_data =
        input_data + (roi_batch_ind * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        const int input_data_index = h * width + w;
        if (offset_input_data[input_data_index] > maxval) {
          maxval = offset_input_data[input_data_index];
          maxidx = input_data_index;
        }
      }
    }
    output_data[i] = maxval;
    if (argmax_data) {
      argmax_data[i] = maxidx;
    }
  }
}

// ROI max-pooling backward kernel.
//
// Each work item i in [0, nthreads) is one element (n, c, ph, pw) of
// output_grad and argmax_data, both laid out as
// [num_rois, channels, pooled_height, pooled_width] (pw fastest).
// output_grad[i] is added to input_grad, which is indexed as
// [batch, channels, height, width]. The target is the plane
// (roi_batch_id_data[n], c), at the offset stored in argmax_data[i].
// An argmax of -1 (empty bin) contributes nothing.
//
// Different work items can target the same input element, so accumulation
// uses atomic adds. input_grad must therefore be initialized (e.g. zeroed)
// before launch.
//
// input_rois, num_rois and spatial_scale are not read. They are kept for
// interface compatibility.
//
// Launch: any 1-D grid/block shape (stride loop). No shared memory is used.
template <typename T>
__global__ void GPUROIPoolBackward(
    const int nthreads, const T* input_rois, const T* output_grad,
    const int64_t* argmax_data, const int num_rois, const float spatial_scale,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, int* roi_batch_id_data,
    T* input_grad) {
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += stride) {
    // Only the channel and ROI are needed. The element's position inside the
    // pooled output is i itself, because
    // i == ((n * channels + c) * pooled_height + ph) * pooled_width + pw.
    const int c = (i / pooled_width / pooled_height) % channels;
    const int n = i / pooled_width / pooled_height / channels;

    const int roi_batch_ind = roi_batch_id_data[n];
    T* offset_input_grad =
        input_grad + (roi_batch_ind * channels + c) * height * width;

    // Keep the full 64-bit value rather than narrowing it to int.
    const int64_t argmax = argmax_data[i];
    if (argmax != -1) {
      platform::CudaAtomicAdd(offset_input_grad + argmax,
                              static_cast<T>(output_grad[i]));
    }
  }
}

// GPU kernel of the roi_pool operator.
//
// Inputs:  X, read as [batch, channels, height, width], and ROIs, a LoDTensor
//          whose last LoD level groups ROI rows by image (rows
//          [lod[n], lod[n + 1]) belong to image n).
// Outputs: Out and Argmax, filled element-wise by GPUROIPoolForward.
// Attrs:   pooled_height, pooled_width, spatial_scale.
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* out = ctx.Output<Tensor>("Out");
    auto* argmax = ctx.Output<Tensor>("Argmax");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    int batch_size = in_dims[0];
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];

    int rois_num = rois->dims()[0];
    if (rois_num == 0) return;

    // Build, on the host, the image (batch) index of every ROI row from the
    // last LoD level of ROIs. lod().back() is undefined on an empty LoD, so
    // reject that up front.
    PADDLE_ENFORCE(!rois->lod().empty(),
                   "Input(ROIs) of ROIPoolOp must carry LoD information.");
    framework::Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({rois_num});
    auto cplace = platform::CPUPlace();
    int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
    auto rois_lod = rois->lod().back();
    int rois_batch_size = rois_lod.size() - 1;
    PADDLE_ENFORCE_EQ(
        rois_batch_size, batch_size,
        "The rois_batch_size and imgs batch_size must be the same.");
    int rois_num_with_lod = rois_lod[rois_batch_size];
    PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
                      "The rois_num from input and lod must be the same.");
    for (int n = 0; n < rois_batch_size; ++n) {
      for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
        roi_batch_id_data[i] = n;
      }
    }

    // Copy the batch ids to device memory on the kernel's stream.
    auto& dev_ctx = ctx.cuda_device_context();
    int bytes = roi_batch_id_list.numel() * sizeof(int);
    auto roi_ptr = memory::Alloc(dev_ctx, bytes);
    int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
    const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
    memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
                 dev_ctx.stream());

    int output_size = out->numel();
    int blocks = NumBlocks(output_size);
    int threads = kNumCUDAThreads;
    // The forward kernel skips argmax writes when given nullptr, so tolerate
    // an absent Argmax output instead of dereferencing it.
    int64_t* argmax_data =
        argmax ? argmax->mutable_data<int64_t>(ctx.GetPlace()) : nullptr;
    GPUROIPoolForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
        height, width, pooled_height, pooled_width, roi_id_data,
        out->mutable_data<T>(ctx.GetPlace()), argmax_data);
  }
};

// GPU kernel of the roi_pool_grad operator.
//
// Scatters Out@GRAD back into X@GRAD at the positions recorded in Argmax by
// the forward pass (see GPUROIPoolBackward). ROIs must carry the same LoD
// grouping as in the forward op. X@GRAD is zero-filled first because the
// backward kernel accumulates into it.
template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* argmax = ctx.Input<Tensor>("Argmax");

    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    int rois_num = rois->dims()[0];
    int channels = in->dims()[1];
    int height = in->dims()[2];
    int width = in->dims()[3];

    // X@GRAD not requested: nothing to compute.
    if (!x_grad) return;

    // Start from zero, since the backward kernel only adds into X@GRAD. This
    // also leaves the correct all-zero gradient when nothing is scattered.
    auto& dev_ctx = ctx.cuda_device_context();
    T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
    math::SetConstant<Place, T> set_zero;
    set_zero(dev_ctx, x_grad, static_cast<T>(0));

    int output_grad_size = out_grad->numel();
    if (rois_num == 0 || output_grad_size == 0) return;

    // Build, on the host, the image (batch) index of every ROI row from the
    // last LoD level of ROIs (rows [lod[n], lod[n + 1]) belong to image n).
    PADDLE_ENFORCE(!rois->lod().empty(),
                   "Input(ROIs) of ROIPoolGradOp must carry LoD information.");
    framework::Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({rois_num});
    auto cplace = platform::CPUPlace();
    int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
    auto rois_lod = rois->lod().back();
    int rois_batch_size = rois_lod.size() - 1;
    // The fill loop below writes indices up to lod.back(). Make sure that
    // stays within the rois_num entries allocated above.
    int rois_num_with_lod = rois_lod[rois_batch_size];
    PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
                      "The rois_num from input and lod must be the same.");
    for (int n = 0; n < rois_batch_size; ++n) {
      for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
        roi_batch_id_data[i] = n;
      }
    }

    // Copy the batch ids to device memory on the kernel's stream.
    int bytes = roi_batch_id_list.numel() * sizeof(int);
    auto roi_ptr = memory::Alloc(dev_ctx, bytes);
    int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
    const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
    memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
                 dev_ctx.stream());

    int blocks = NumBlocks(output_grad_size);
    int threads = kNumCUDAThreads;
    GPUROIPoolBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        output_grad_size, rois->data<T>(), out_grad->data<T>(),
        argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
        width, pooled_height, pooled_width, roi_id_data, x_grad_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the float and double CUDA kernels for roi_pool and its gradient.
REGISTER_OP_CUDA_KERNEL(
    roi_pool,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    roi_pool_grad,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, double>);