prroi_pool_op.cu 16.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/prroi_pool_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

// Returns how many kNumCUDAThreads-wide blocks are needed to cover N work
// items, clamped to kNumMaximumNumBlocks. The kernels in this file loop over
// their index space with a stride of blockDim.x * gridDim.x, so a clamped grid
// still visits every item. Returns 0 when N <= 0.
//
// The ceiling division is written as quotient + remainder test instead of
// (N + kNumCUDAThreads - 1) / kNumCUDAThreads so it cannot overflow int when
// N is close to INT_MAX.
static inline int NumBlocks(const int N) {
  if (N <= 0) {
    return 0;
  }
  const int needed =
      N / kNumCUDAThreads + ((N % kNumCUDAThreads) != 0 ? 1 : 0);
  return std::min(needed, kNumMaximumNumBlocks);
}

// Precise RoI Pooling forward kernel.
//
// Launch layout: 1-D grid of 1-D blocks. Each thread walks the flat output
// index space with a stride of blockDim.x * gridDim.x, so any grid size
// covers all `nthreads` outputs. No shared memory is used.
//
// Parameters:
//   nthreads           - number of output elements to produce.
//   input_data         - feature map; plane (batch b, channel ch) starts at
//                        (b * input_channels + ch) * height * width. Indexing
//                        inside a plane is done by PrRoIPoolingMatCalculation.
//   input_rois         - 4 values per RoI, read as (x1, y1, x2, y2) and
//                        multiplied by spatial_scale.
//   rois_batch_id_data - image index of each RoI (device-accessible).
//   output_data        - written in (n, c, ph, pw) order.
//
// Each output is sum(PrRoIPoolingMatCalculation over the input cells touched
// by the bin) / bin area, or 0 when the bin has zero area. (The helper lives
// in prroi_pool_op.h; presumably it integrates the bilinearly interpolated
// input over the bin/cell overlap - confirm there.)
template <typename T>
__global__ void GPUPRROIPoolForward(const int nthreads,
                                    const T* input_data,
                                    const T* input_rois,
                                    const float spatial_scale,
                                    const int input_channels,
                                    const int height,
                                    const int width,
                                    const int output_channels,
                                    const int pooled_height,
                                    const int pooled_width,
                                    const int* rois_batch_id_data,
                                    T* output_data) {
  // 64-bit index/stride: `i += stride` must not overflow near INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < nthreads;
       i += stride) {
    // The output is in order (n, c, ph, pw)
    const int pw = static_cast<int>(i % pooled_width);
    const int ph = static_cast<int>((i / pooled_width) % pooled_height);
    const int c =
        static_cast<int>((i / pooled_width / pooled_height) % output_channels);
    const int n =
        static_cast<int>(i / pooled_width / pooled_height / output_channels);

    const int roi_batch_id = rois_batch_id_data[n];

    // RoI box scaled into feature-map coordinates.
    const T* offset_input_rois = input_rois + n * 4;
    const T roi_start_w = static_cast<T>(offset_input_rois[0]) * spatial_scale;
    const T roi_start_h = static_cast<T>(offset_input_rois[1]) * spatial_scale;
    const T roi_end_w = static_cast<T>(offset_input_rois[2]) * spatial_scale;
    const T roi_end_h = static_cast<T>(offset_input_rois[3]) * spatial_scale;

    const T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(0.0));
    const T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(0.0));

    // Continuous extent of bin (ph, pw) on the input feature map.
    const T bin_size_h = roi_height / static_cast<T>(pooled_height);
    const T bin_size_w = roi_width / static_cast<T>(pooled_width);

    const T win_start_w = roi_start_w + bin_size_w * pw;
    const T win_start_h = roi_start_h + bin_size_h * ph;
    const T win_end_w = win_start_w + bin_size_w;
    const T win_end_h = win_start_h + bin_size_h;

    const T win_size = max(static_cast<T>(0.0), bin_size_w * bin_size_h);

    // Not position-sensitive: output channel c pools input channel c.
    const int input_channel = c;
    // 64-bit offset so feature maps with more than INT_MAX elements work.
    const T* offset_input_data =
        input_data +
        (static_cast<int64_t>(roi_batch_id) * input_channels + input_channel) *
            height * width;

    if (win_size > static_cast<T>(0.0)) {
      // Integer cell range [s, e) covering the bin.
      const int s_w = static_cast<int>(floor(win_start_w));
      const int e_w = static_cast<int>(ceil(win_end_w));
      const int s_h = static_cast<int>(floor(win_start_h));
      const int e_h = static_cast<int>(ceil(win_end_h));
      T sum_out = 0;

      for (int w_iter = s_w; w_iter < e_w; ++w_iter) {
        for (int h_iter = s_h; h_iter < e_h; ++h_iter) {
          const T cell_h = static_cast<T>(h_iter);
          const T cell_w = static_cast<T>(w_iter);
          // Clip the bin to the unit cell [h_iter, h_iter + 1) x
          // [w_iter, w_iter + 1) before integrating.
          sum_out += PrRoIPoolingMatCalculation(
              offset_input_data,
              h_iter,
              w_iter,
              h_iter + 1,
              w_iter + 1,
              max(win_start_h, cell_h),
              max(win_start_w, cell_w),
              min(win_end_h, cell_h + static_cast<T>(1.0)),
              min(win_end_w, cell_w + static_cast<T>(1.0)),
              height,
              width);
        }
      }
      output_data[i] = sum_out / win_size;
    } else {
      output_data[i] = static_cast<T>(0);
    }
  }
}

// Precise RoI Pooling backward kernel.
//
// Launch layout: 1-D grid of 1-D blocks. Each thread walks the flat index
// space of the output gradient with a stride of blockDim.x * gridDim.x, so any
// grid size covers all `nthreads` elements. No shared memory is used.
//
// For output element i = (n, c, ph, pw) it:
//   * spreads output_grad[i] / bin_area over the input cells touched by the
//     bin via PrRoIPoolingMatDistributeDiff (into input_grad_data), and
//   * accumulates the gradient w.r.t. the 4 RoI coordinates of RoI n via
//     PrRoIPoolingCoorBackward (into input_roi_grad_data + n * 4).
// Many outputs map to the same input cells / RoI, so those helpers presumably
// accumulate atomically. The host code zero-fills both gradient buffers before
// launching (confirm the atomic behavior in prroi_pool_op.h).
//
// Layouts match GPUPRROIPoolForward:
//   * in_data and input_grad_data are (batch, input_channels) planes of
//     height * width elements.
//   * input_rois holds 4 values per RoI, read as (x1, y1, x2, y2) and scaled by
//     spatial_scale.
//   * out_data and output_grad_data are in (n, c, ph, pw) order.
template <typename T>
__global__ void GPUPRROIPoolBackward(const int nthreads,
                                     const T* in_data,
                                     const T* input_rois,
                                     const T* output_grad_data,
                                     const float spatial_scale,
                                     const int input_channels,
                                     const int height,
                                     const int width,
                                     const int output_channels,
                                     const int pooled_height,
                                     const int pooled_width,
                                     const int* rois_batch_id_data,
                                     T* input_grad_data,
                                     const T* out_data,
                                     T* input_roi_grad_data) {
  // 64-bit index/stride: `i += stride` must not overflow near INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < nthreads;
       i += stride) {
    // The output is in order (n, c, ph, pw)
    const int pw = static_cast<int>(i % pooled_width);
    const int ph = static_cast<int>((i / pooled_width) % pooled_height);
    const int c =
        static_cast<int>((i / pooled_width / pooled_height) % output_channels);
    const int n =
        static_cast<int>(i / pooled_width / pooled_height / output_channels);

    const int roi_batch_id = rois_batch_id_data[n];

    // Not position-sensitive: output channel c maps to input channel c.
    const int input_channel = c;
    // 64-bit offset so feature maps with more than INT_MAX elements work.
    const int64_t input_offset =
        (static_cast<int64_t>(roi_batch_id) * input_channels + input_channel) *
        height * width;
    T* offset_input_grad_data = input_grad_data + input_offset;
    const T* offset_output_grad_data = output_grad_data + i;

    // RoI box scaled into feature-map coordinates.
    const T* offset_input_rois = input_rois + n * 4;
    const T roi_start_w = static_cast<T>(offset_input_rois[0]) * spatial_scale;
    const T roi_start_h = static_cast<T>(offset_input_rois[1]) * spatial_scale;
    const T roi_end_w = static_cast<T>(offset_input_rois[2]) * spatial_scale;
    const T roi_end_h = static_cast<T>(offset_input_rois[3]) * spatial_scale;

    T* offset_input_roi_grad_data = input_roi_grad_data + n * 4;

    const T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(0.0));
    const T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(0.0));

    // Continuous extent of bin (ph, pw) on the input feature map.
    const T bin_size_h = roi_height / static_cast<T>(pooled_height);
    const T bin_size_w = roi_width / static_cast<T>(pooled_width);

    const T win_start_w = roi_start_w + bin_size_w * pw;
    const T win_start_h = roi_start_h + bin_size_h * ph;
    const T win_end_w = win_start_w + bin_size_w;
    const T win_end_h = win_start_h + bin_size_h;

    const T win_size = max(static_cast<T>(0.0), bin_size_w * bin_size_h);

    // Integer cell range [s, e) covering the bin.
    const int s_w = static_cast<int>(floor(win_start_w));
    const int e_w = static_cast<int>(ceil(win_end_w));
    const int s_h = static_cast<int>(floor(win_start_h));
    const int e_h = static_cast<int>(ceil(win_end_h));

    // Forward divided by the bin area, so the incoming gradient is scaled by
    // 1 / win_size; a zero-area bin contributed 0 and gets no gradient.
    const T sum_out = win_size == static_cast<T>(0.)
                          ? static_cast<T>(0.)
                          : *offset_output_grad_data / win_size;

    for (int w_iter = s_w; w_iter < e_w; ++w_iter) {
      for (int h_iter = s_h; h_iter < e_h; ++h_iter) {
        const T cell_h = static_cast<T>(h_iter);
        const T cell_w = static_cast<T>(w_iter);
        PrRoIPoolingMatDistributeDiff<T>(
            offset_input_grad_data,
            sum_out,
            h_iter,
            w_iter,
            h_iter + 1,
            w_iter + 1,
            max(win_start_h, cell_h),
            max(win_start_w, cell_w),
            min(win_end_h, cell_h + static_cast<T>(1.0)),
            min(win_end_w, cell_w + static_cast<T>(1.0)),
            height,
            width);
      }
    }

    const T* offset_out_data = out_data + i;
    const T* offset_in_data = in_data + input_offset;
    PrRoIPoolingCoorBackward<T>(s_w,
                                e_w,
                                s_h,
                                e_h,
                                width,
                                height,
                                win_start_w,
                                win_start_h,
                                win_end_w,
                                win_end_h,
                                pw,
                                ph,
                                pooled_width,
                                pooled_height,
                                win_size,
                                spatial_scale,
                                offset_in_data,
                                offset_out_data,
                                offset_input_roi_grad_data,
                                offset_output_grad_data);
  }
}

template <typename T>
class GPUPRROIPoolOpKernel : public framework::OpKernel<T> {
 public:
  // Forward pass of Precise RoI Pooling on GPU.
  //
  // Inputs:
  //   X            - feature map; dims are read as (batch, channels, height,
  //                  width).
  //   ROIs         - 4 values per RoI; the kernel reads them as
  //                  (x1, y1, x2, y2) and multiplies them by spatial_scale.
  //   BatchRoINums - optional per-image RoI counts. When present, they assign
  //                  RoIs to images instead of the LoD of ROIs.
  // Output:
  //   Out          - written in (n, c, ph, pw) order. Presumably shaped
  //                  (num_rois, channels, pooled_height, pooled_width) by
  //                  InferShape - confirm there.
  //
  // Raises InvalidArgument when the RoI-to-image assignment is inconsistent:
  //   * missing LoD without BatchRoINums;
  //   * negative counts, or counts that do not sum to the RoI count;
  //   * a batch-size mismatch.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* out = ctx.Output<Tensor>("Out");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    int batch_size = in_dims[0];
    int input_channels = in_dims[1];
    // Not position-sensitive: one output channel per input channel.
    auto output_channels = input_channels;
    int height = in_dims[2];
    int width = in_dims[3];

    int rois_num = rois->dims()[0];
    if (rois_num == 0) return;

    // Host-side table: image index of every RoI.
    framework::Tensor rois_batch_id_list;
    rois_batch_id_list.Resize({rois_num});
    int* rois_batch_id_data =
        rois_batch_id_list.mutable_data<int>(platform::CPUPlace());

    if (ctx.HasInput("BatchRoINums")) {
      // RoIs are grouped by image in order: the first batch_index[0] RoIs
      // belong to image 0, the next batch_index[1] to image 1, and so on.
      auto* batchroinum = ctx.Input<Tensor>("BatchRoINums");
      framework::Tensor batch_index_cpu;
      framework::TensorCopySync(
          *batchroinum, platform::CPUPlace(), &batch_index_cpu);

      int rois_batch_size = batchroinum->dims()[0];
      PADDLE_ENFORCE_LE(
          rois_batch_size,
          batch_size,
          platform::errors::InvalidArgument(
              "The length of BatchRoINums must not exceed input(X) "
              "batch_size."));
      auto* batch_index = batch_index_cpu.data<int64_t>();
      // Validate before writing so malformed counts can neither overrun
      // rois_batch_id_data nor leave entries uninitialized.
      int64_t total_rois = 0;
      for (int n = 0; n < rois_batch_size; ++n) {
        PADDLE_ENFORCE_GE(batch_index[n],
                          static_cast<int64_t>(0),
                          platform::errors::InvalidArgument(
                              "BatchRoINums must not contain negative "
                              "counts."));
        total_rois += batch_index[n];
      }
      PADDLE_ENFORCE_EQ(
          total_rois,
          static_cast<int64_t>(rois_num),
          platform::errors::InvalidArgument(
              "The sum of BatchRoINums must equal the number of ROIs."));

      int c = 0;
      for (int n = 0; n < rois_batch_size; ++n) {
        for (int64_t k = 0; k < batch_index[n]; ++k) {
          rois_batch_id_data[c++] = n;
        }
      }
    } else {
      // Without BatchRoINums the image assignment comes from the LoD. Checked
      // explicitly: previously an empty LoD here dereferenced a null input.
      PADDLE_ENFORCE_EQ(rois->lod().empty(),
                        false,
                        platform::errors::InvalidArgument(
                            "the lod of Input ROIs should not be empty when "
                            "BatchRoINums is None!"));
      auto rois_lod = rois->lod().back();
      int rois_batch_size = rois_lod.size() - 1;
      PADDLE_ENFORCE_EQ(
          rois_batch_size,
          batch_size,
          platform::errors::InvalidArgument(
              "The rois_batch_size and input(X) batch_size must be the same."));
      int rois_num_with_lod = rois_lod[rois_batch_size];
      PADDLE_ENFORCE_EQ(
          rois_num,
          rois_num_with_lod,
          platform::errors::InvalidArgument(
              "The rois_num from input and lod must be the same."));

      for (int n = 0; n < rois_batch_size; ++n) {
        for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
          rois_batch_id_data[i] = n;
        }
      }
    }

    int output_size = out->numel();
    T* output_data = out->mutable_data<T>(ctx.GetPlace());
    // A launch with 0 blocks is an invalid configuration; nothing to do.
    if (output_size == 0) return;

    // Stage the batch-id table on the device, ordered on the compute stream.
    auto cplace = platform::CPUPlace();
    auto& dev_ctx = ctx.cuda_device_context();
    size_t bytes =
        static_cast<size_t>(rois_batch_id_list.numel()) * sizeof(int);
    auto roi_ptr = memory::Alloc(
        dev_ctx.GetPlace(),
        bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
    int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
    const auto gplace = ctx.GetPlace();
    memory::Copy(gplace,
                 roi_id_data,
                 cplace,
                 rois_batch_id_data,
                 bytes,
                 dev_ctx.stream());

    int blocks = NumBlocks(output_size);
    int threads = kNumCUDAThreads;
    GPUPRROIPoolForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        output_size,
        in->data<T>(),
        rois->data<T>(),
        spatial_scale,
        input_channels,
        height,
        width,
        output_channels,
        pooled_height,
        pooled_width,
        roi_id_data,
        output_data);
  }
};

template <typename DeviceContext, typename T>
class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  // Backward pass of Precise RoI Pooling on GPU.
  //
  // Produces X@GRAD and/or ROIs@GRAD, whichever is requested. Both buffers are
  // zero-filled before the kernel runs. The RoI-to-image assignment is built
  // exactly as in the forward op: from BatchRoINums if present, else from the
  // LoD of ROIs.
  //
  // Raises InvalidArgument when that assignment is inconsistent:
  //   * missing LoD without BatchRoINums;
  //   * negative counts, or counts that do not sum to the RoI count;
  //   * a batch-size mismatch.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* out = ctx.Input<framework::Tensor>("Out");

    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* input_roi_grad =
        ctx.Output<LoDTensor>(framework::GradVarName("ROIs"));

    // Nothing to compute when no gradient is requested.
    if (input_grad == nullptr && input_roi_grad == nullptr) return;

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    int rois_num = rois->dims()[0];
    int batch_size = in->dims()[0];
    int input_channels = in->dims()[1];
    // Not position-sensitive: one output channel per input channel.
    auto output_channels = input_channels;
    int height = in->dims()[2];
    int width = in->dims()[3];

    // Host-side table: image index of every RoI.
    framework::Tensor rois_batch_id_list;
    rois_batch_id_list.Resize({rois_num});
    int* rois_batch_id_data =
        rois_batch_id_list.mutable_data<int>(platform::CPUPlace());

    if (ctx.HasInput("BatchRoINums")) {
      // RoIs are grouped by image in order: the first batch_index[0] RoIs
      // belong to image 0, the next batch_index[1] to image 1, and so on.
      auto* batchroinum = ctx.Input<Tensor>("BatchRoINums");
      framework::Tensor batch_index_cpu;
      framework::TensorCopySync(
          *batchroinum, platform::CPUPlace(), &batch_index_cpu);

      int rois_batch_size = batchroinum->dims()[0];
      PADDLE_ENFORCE_LE(
          rois_batch_size,
          batch_size,
          platform::errors::InvalidArgument(
              "The length of BatchRoINums must not exceed input(X) "
              "batch_size."));
      auto* batch_index = batch_index_cpu.data<int64_t>();
      // Validate before writing so malformed counts can neither overrun
      // rois_batch_id_data nor leave entries uninitialized.
      int64_t total_rois = 0;
      for (int n = 0; n < rois_batch_size; ++n) {
        PADDLE_ENFORCE_GE(batch_index[n],
                          static_cast<int64_t>(0),
                          platform::errors::InvalidArgument(
                              "BatchRoINums must not contain negative "
                              "counts."));
        total_rois += batch_index[n];
      }
      PADDLE_ENFORCE_EQ(
          total_rois,
          static_cast<int64_t>(rois_num),
          platform::errors::InvalidArgument(
              "The sum of BatchRoINums must equal the number of ROIs."));

      int c = 0;
      for (int n = 0; n < rois_batch_size; ++n) {
        for (int64_t k = 0; k < batch_index[n]; ++k) {
          rois_batch_id_data[c++] = n;
        }
      }
    } else {
      // Without BatchRoINums the image assignment comes from the LoD. This
      // check used to sit after an `|| lod().empty()` test and so never
      // fired; an empty LoD then dereferenced a null input.
      PADDLE_ENFORCE_EQ(rois->lod().empty(),
                        false,
                        platform::errors::InvalidArgument(
                            "the lod of Input ROIs should not be empty when "
                            "BatchRoINums is None!"));
      auto rois_lod = rois->lod().back();
      int rois_batch_size = rois_lod.size() - 1;
      PADDLE_ENFORCE_EQ(
          rois_batch_size,
          batch_size,
          platform::errors::InvalidArgument(
              "The rois_batch_size and input(X) batch_size must be the same."));
      int rois_num_with_lod = rois_lod[rois_batch_size];
      PADDLE_ENFORCE_EQ(
          rois_num,
          rois_num_with_lod,
          platform::errors::InvalidArgument(
              "The rois_num from input and lod must be the same."));

      for (int n = 0; n < rois_batch_size; ++n) {
        for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
          rois_batch_id_data[i] = n;
        }
      }
    }

    // Stage the batch-id table on the device, ordered on the compute stream.
    auto cplace = platform::CPUPlace();
    auto& dev_ctx = ctx.cuda_device_context();
    size_t bytes =
        static_cast<size_t>(rois_batch_id_list.numel()) * sizeof(int);
    auto roi_ptr = memory::Alloc(
        dev_ctx.GetPlace(),
        bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
    int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
    const auto gplace = ctx.GetPlace();
    memory::Copy(gplace,
                 roi_id_data,
                 cplace,
                 rois_batch_id_data,
                 bytes,
                 dev_ctx.stream());

    // The backward kernel always writes both gradients. If only one was
    // requested, send the other into a scratch tensor that is discarded;
    // previously the missing output was dereferenced as a null pointer.
    framework::Tensor input_grad_scratch;
    framework::Tensor input_roi_grad_scratch;
    Tensor* x_grad = input_grad;
    if (x_grad == nullptr) {
      input_grad_scratch.Resize(in->dims());
      x_grad = &input_grad_scratch;
    }
    Tensor* rois_grad = input_roi_grad;
    if (rois_grad == nullptr) {
      input_roi_grad_scratch.Resize(rois->dims());
      rois_grad = &input_roi_grad_scratch;
    }

    // Zero-fill: the kernel's helpers accumulate into these buffers.
    phi::funcs::SetConstant<DeviceContext, T> set_zero;
    T* input_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
    set_zero(dev_ctx, x_grad, static_cast<T>(0));
    T* input_roi_grad_data = rois_grad->mutable_data<T>(ctx.GetPlace());
    set_zero(dev_ctx, rois_grad, static_cast<T>(0));

    int output_grad_size = output_grad->numel();
    // A launch with 0 blocks is an invalid configuration; nothing to do.
    if (output_grad_size > 0) {
      int blocks = NumBlocks(output_grad_size);
      int threads = kNumCUDAThreads;
      GPUPRROIPoolBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
          output_grad_size,
          in->data<T>(),
          rois->data<T>(),
          output_grad->data<T>(),
          spatial_scale,
          input_channels,
          height,
          width,
          output_channels,
          pooled_height,
          pooled_width,
          roi_id_data,
          input_grad_data,
          out->data<T>(),
          input_roi_grad_data);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the GPU forward and backward kernels for float and double.
REGISTER_OP_CUDA_KERNEL(prroi_pool,
                        ops::GPUPRROIPoolOpKernel<float>,
                        ops::GPUPRROIPoolOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(prroi_pool_grad,
                        ops::GPUPRROIPoolGradOpKernel<phi::GPUContext, float>,
                        ops::GPUPRROIPoolGradOpKernel<phi::GPUContext, double>);