/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <limits>
#include <vector>

#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/roi_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Short aliases for the framework tensor types used throughout this file.
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

// Number of threads per block used for every kernel launch in this file.
static constexpr int kNumCUDAThreads = 512;
// Upper bound on the grid size returned by NumBlocks(). The kernels in this
// file advance their loop index by blockDim.x * gridDim.x, so a capped grid
// still visits every element.
static constexpr int kNumMaxinumNumBlocks = 4096;

// Returns the grid size for N work items: ceil(N / kNumCUDAThreads), capped
// at kNumMaxinumNumBlocks. N is expected to be non-negative; N == 0 yields 0.
static inline int NumBlocks(const int N) {
  // Ceil-division written without `N + kNumCUDAThreads - 1`, which overflows
  // signed int when N is close to INT_MAX.
  const int full_blocks = N / kNumCUDAThreads;
  const int tail_block = (N % kNumCUDAThreads > 0) ? 1 : 0;
  return std::min(full_blocks + tail_block, kNumMaxinumNumBlocks);
}

// Max-pools every region of interest (ROI) into a
// pooled_height x pooled_width grid. Each loop iteration produces one output
// element; the flat output index i decomposes as (n, c, ph, pw) with pw
// varying fastest.
//
//   nthreads           total number of output elements to produce
//   input_data         input feature map, addressed as
//                      ((batch * channels + c) * height + h) * width + w
//   input_rois         kROISize values per ROI, read as
//                      (start_w, start_h, end_w, end_h) before scaling
//   spatial_scale      factor applied to the ROI coordinates before rounding
//   roi_batch_id_data  batch index of the input image for every ROI
//   output_data        receives the maximum of each bin (0 for an empty bin)
//   argmax_data        may be null; otherwise receives the offset of the
//                      selected element inside its (batch, channel) plane,
//                      or -1 when no element was selected
template <typename T>
__global__ void GPUROIPoolForward(
    const int nthreads, const T* input_data, const T* input_rois,
    const float spatial_scale, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    int* roi_batch_id_data, T* output_data, int64_t* argmax_data) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (size_t i = index; i < nthreads; i += offset) {
    // Recover (n, c, ph, pw) from the flat output index.
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    const T* offset_input_rois = input_rois + n * kROISize;
    int roi_batch_ind = roi_batch_id_data[n];
    int roi_start_w = round(offset_input_rois[0] * spatial_scale);
    int roi_start_h = round(offset_input_rois[1] * spatial_scale);
    int roi_end_w = round(offset_input_rois[2] * spatial_scale);
    int roi_end_h = round(offset_input_rois[3] * spatial_scale);

    // Degenerate ROIs are forced to be at least 1x1.
    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);

    // Bin boundaries relative to the ROI origin, computed in double.
    int hstart = static_cast<int>(floor(static_cast<double>(ph) *
                                        static_cast<double>(roi_height) /
                                        static_cast<double>(pooled_height)));
    int wstart = static_cast<int>(floor(static_cast<double>(pw) *
                                        static_cast<double>(roi_width) /
                                        static_cast<double>(pooled_width)));
    int hend = static_cast<int>(ceil(static_cast<double>(ph + 1) *
                                     static_cast<double>(roi_height) /
                                     static_cast<double>(pooled_height)));
    int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) *
                                     static_cast<double>(roi_width) /
                                     static_cast<double>(pooled_width)));

    // Shift into feature-map coordinates and clip to the map bounds.
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
    int maxidx = -1;
    // The plane offset is computed in 64 bits so that inputs with more than
    // INT_MAX elements do not overflow the pointer arithmetic.
    const T* offset_input_data =
        input_data +
        (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_data_index = h * width + w;
        if (offset_input_data[input_data_index] > maxval) {
          maxval = offset_input_data[input_data_index];
          maxidx = input_data_index;
        }
      }
    }
    output_data[i] = maxval;
    if (argmax_data) {
      argmax_data[i] = maxidx;
    }
  }
}

// Scatters the gradient of every pooled output element back to the input
// element that was selected as its maximum in the forward pass.
//
//   nthreads           total number of output-gradient elements
//   output_grad        gradient w.r.t. the pooled output
//   argmax_data        per-output offset of the selected input element inside
//                      its (batch, channel) plane; -1 means "nothing selected"
//   roi_batch_id_data  batch index of the input image for every ROI
//   input_grad         gradient w.r.t. the input; this kernel only adds to
//                      it, it does not initialise it
//
// input_rois, num_rois and spatial_scale are accepted but not read here.
template <typename T>
__global__ void GPUROIPoolBackward(
    const int nthreads, const T* input_rois, const T* output_grad,
    const int64_t* argmax_data, const int num_rois, const float spatial_scale,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, int* roi_batch_id_data,
    T* input_grad) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    // Recover (n, c, ph, pw) from the flat output index.
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    int roi_batch_ind = roi_batch_id_data[n];
    // Computed in 64 bits so that inputs with more than INT_MAX elements do
    // not overflow the offset.
    int64_t input_offset =
        (static_cast<int64_t>(roi_batch_ind) * channels + c) * height * width;
    int output_offset = (n * channels + c) * pooled_height * pooled_width;
    const T* offset_output_grad = output_grad + output_offset;
    T* offset_input_grad = input_grad + input_offset;
    const int64_t* offset_argmax_data = argmax_data + output_offset;

    int argmax = offset_argmax_data[ph * pooled_width + pw];
    if (argmax != -1) {
      // Atomic because several output elements may have selected the same
      // input element.
      platform::CudaAtomicAdd(
          offset_input_grad + argmax,
          static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
    }
  }
}

// CUDA forward kernel of the roi_pool operator.
//
// Reads inputs "X" and "ROIs" (plus the optional "RoisNum"), the attributes
// "pooled_height", "pooled_width" and "spatial_scale", and writes the outputs
// "Out" and "Argmax". Returns without launching anything when "ROIs" has no
// rows.
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* out = ctx.Output<Tensor>("Out");
    auto* argmax = ctx.Output<Tensor>("Argmax");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    int batch_size = in_dims[0];
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];

    int rois_num = rois->dims()[0];

    if (rois_num == 0) return;

    int output_size = out->numel();
    int blocks = NumBlocks(output_size);
    int threads = kNumCUDAThreads;

    // Host-side table mapping every ROI to the batch index of its image.
    framework::Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({rois_num});
    auto cplace = platform::CPUPlace();
    int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
    auto& dev_ctx = ctx.cuda_device_context();
    auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
    if (ctx.HasInput("RoisNum")) {
      // "RoisNum" holds the number of ROIs per image and lives on the device.
      auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
      int rois_batch_size = rois_num_t->numel();

      PADDLE_ENFORCE_EQ(
          rois_batch_size, batch_size,
          platform::errors::InvalidArgument(
              "The batch size of input(ROIs) and input(X) must be the same but "
              "received batch size of input(ROIs) and input(X) is %d and %d "
              "respectively.",
              rois_batch_size, batch_size));
      std::vector<int> rois_num_list(rois_batch_size);
      // Copied with stream argument 0; the values are consumed on the host
      // immediately below.
      memory::Copy(cplace, rois_num_list.data(), gplace,
                   rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
      // NOTE(review): the per-image counts are not checked against rois_num;
      // a sum larger than rois_num would write past roi_batch_id_data.
      int start = 0;
      for (int n = 0; n < rois_batch_size; ++n) {
        for (int i = start; i < start + rois_num_list[n]; ++i) {
          roi_batch_id_data[i] = n;
        }
        start += rois_num_list[n];
      }
    } else {
      // Without "RoisNum" the ROI-to-image mapping comes from the last LoD
      // level of "ROIs".
      auto rois_lod = rois->lod().back();
      int rois_batch_size = rois_lod.size() - 1;
      PADDLE_ENFORCE_EQ(
          rois_batch_size, batch_size,
          platform::errors::InvalidArgument(
              "The batch size of input(ROIs) and input(X) must be the same but "
              "received batch size of input(ROIs) and input(X) is %d and %d "
              "respectively.",
              rois_batch_size, batch_size));

      int rois_num_with_lod = rois_lod[rois_batch_size];
      // The message takes exactly two values; the previous text had a third
      // "%d" with no matching argument.
      PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
                        platform::errors::InvalidArgument(
                            "The number of rois from input(ROIs) and its LOD "
                            "must be the same. Received rois %d of input(ROIs) "
                            "but the number of rois from its LOD is %d",
                            rois_num, rois_num_with_lod));
      for (int n = 0; n < rois_batch_size; ++n) {
        for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
          roi_batch_id_data[i] = n;
        }
      }
    }

    // Upload the ROI-to-image table on the same stream as the kernel launch.
    int bytes = roi_batch_id_list.numel() * sizeof(int);
    auto roi_ptr = memory::Alloc(dev_ctx, bytes);
    int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
    memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
                 dev_ctx.stream());

    GPUROIPoolForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
        height, width, pooled_height, pooled_width, roi_id_data,
        out->mutable_data<T>(ctx.GetPlace()),
        argmax->mutable_data<int64_t>(ctx.GetPlace()));
  }
};

// CUDA backward kernel of the roi_pool operator.
//
// Reads inputs "X", "ROIs", "Argmax" (plus the optional "RoisNum") and the
// gradient of "Out", and writes the gradient of "X". Does nothing when the
// gradient of "X" is not requested.
template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* argmax = ctx.Input<Tensor>("Argmax");

    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    int rois_num = rois->dims()[0];
    int channels = in->dims()[1];
    int height = in->dims()[2];
    int width = in->dims()[3];

    if (x_grad) {
      // Host-side table mapping every ROI to the batch index of its image.
      framework::Tensor roi_batch_id_list;
      roi_batch_id_list.Resize({rois_num});
      auto cplace = platform::CPUPlace();
      int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);

      auto& dev_ctx = ctx.cuda_device_context();
      auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
      if (ctx.HasInput("RoisNum")) {
        // "RoisNum" holds the number of ROIs per image and lives on the
        // device; bring it to the host to expand it into the table.
        auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
        int rois_batch_size = rois_num_t->numel();
        std::vector<int> rois_num_list(rois_batch_size);
        memory::Copy(cplace, rois_num_list.data(), gplace,
                     rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
        int start = 0;
        for (int n = 0; n < rois_batch_size; ++n) {
          for (int i = start; i < start + rois_num_list[n]; ++i) {
            roi_batch_id_data[i] = n;
          }
          start += rois_num_list[n];
        }
      } else {
        // Without "RoisNum" the mapping comes from the last LoD level of
        // "ROIs".
        auto rois_lod = rois->lod().back();
        int rois_batch_size = rois_lod.size() - 1;
        for (int n = 0; n < rois_batch_size; ++n) {
          for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
            roi_batch_id_data[i] = n;
          }
        }
      }

      // Upload the ROI-to-image table on the same stream as the kernel launch.
      int bytes = roi_batch_id_list.numel() * sizeof(int);
      auto roi_ptr = memory::Alloc(dev_ctx, bytes);
      int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
      memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
                   dev_ctx.stream());

      // The backward kernel accumulates into x_grad, so it must start at 0.
      x_grad->mutable_data<T>(ctx.GetPlace());
      math::SetConstant<Place, T> set_zero;
      set_zero(dev_ctx, x_grad, static_cast<T>(0));

      int output_grad_size = out_grad->numel();
      int blocks = NumBlocks(output_grad_size);
      int threads = kNumCUDAThreads;

      // Skip the launch for an empty gradient; NumBlocks(0) is 0.
      if (output_grad_size > 0) {
        GPUROIPoolBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
            output_grad_size, rois->data<T>(), out_grad->data<T>(),
            argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
            width, pooled_height, pooled_width, roi_id_data,
            x_grad->mutable_data<T>(ctx.GetPlace()));
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the float and double CUDA kernels for the forward operator.
REGISTER_OP_CUDA_KERNEL(
    roi_pool,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolOpKernel<paddle::platform::CUDADeviceContext, double>);
// Register the float and double CUDA kernels for the gradient operator.
REGISTER_OP_CUDA_KERNEL(
    roi_pool_grad,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIPoolGradOpKernel<paddle::platform::CUDADeviceContext, double>);