// roi_align_op.cu (committed by jerrywgz)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/roi_align_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Short local aliases for the framework tensor types used by the kernels
// below. LoDTensor is used for the "ROIs" input, whose lod() is read to map
// each ROI to a batch index.
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

// Launch configuration used by the kernels in this file: threads per block
// and the upper bound on the number of blocks in the 1-D grid.
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;

// Number of kNumCUDAThreads-sized blocks needed to cover N work items
// (rounded up), capped at kNumMaxinumNumBlocks. When the cap applies, the
// kernels rely on CUDA_1D_KERNEL_LOOP to walk the remaining items.
static inline int NumBlocks(const int N) {
  const int blocks_needed = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  if (blocks_needed < kNumMaxinumNumBlocks) {
    return blocks_needed;
  }
  return kNumMaxinumNumBlocks;
}

// Grid-stride loop header for 1-D launches: declares `int i`, starts it at the
// calling thread's global index (blockIdx.x * blockDim.x + threadIdx.x) and
// advances it by the total number of launched threads (blockDim.x * gridDim.x)
// until it reaches `n`. The statement or block following the macro is the loop
// body, so every index in [0, n) is visited even when the grid is smaller
// than n.
// NOTE(review): `i` is a 32-bit int, so `n` must fit in int.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

/*
template <class T>
inline __device__ T gpu_atomic_add(const T val, T* address) {
  return atomicAdd(address, val);
}
*/

// Bilinearly samples one height x width plane at the fractional location
// (y, x).
//
//   input_data  pointer to the first element of the plane; elements are read
//               as input_data[row * width + col].
//   height      number of rows in the plane.
//   width       number of columns in the plane.
//   y, x        sample coordinates in pixel units (row, column).
//
// Returns 0 when the point lies more than one pixel before the first
// row/column or beyond `height`/`width`. Otherwise the coordinates are clamped
// into the plane and the weighted sum of the four neighbouring pixels is
// returned.
template <class T>
__device__ T bilinear_interpolate(const T* input_data, const int height,
                                  const int width, T y, T x) {
  // Points too far outside the plane contribute nothing.
  if (y < static_cast<T>(-1) || y > height || x < static_cast<T>(-1) ||
      x > width) {
    return 0;
  }
  // Points in the one-pixel band before the first row/column snap to it.
  if (y <= 0) {
    y = 0;
  }
  if (x <= 0) {
    x = 0;
  }
  int y_low = static_cast<int>(y);
  int x_low = static_cast<int>(x);
  int y_high;
  int x_high;
  // On (or past) the last row/column both neighbours collapse onto it, which
  // makes the fractional part zero and keeps every read in bounds.
  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = static_cast<T>(y_low);
  } else {
    y_high = y_low + 1;
  }
  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = static_cast<T>(x_low);
  } else {
    x_high = x_low + 1;
  }
  // Fractional distances to the low neighbour and their complements. The
  // constant is typed as T so float instantiations stay in single precision.
  const T one = static_cast<T>(1);
  const T ly = y - y_low, lx = x - x_low;
  const T hy = one - ly, hx = one - lx;

  const T v1 = input_data[y_low * width + x_low];
  const T v2 = input_data[y_low * width + x_high];
  const T v3 = input_data[y_high * width + x_low];
  const T v4 = input_data[y_high * width + x_high];
  const T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}

// Computes the four bilinear weights and the neighbouring pixel indices for
// the sample location (y, x) in a height x width plane. It applies the same
// clamping rules as bilinear_interpolate, so the backward pass scatters
// gradients to exactly the pixels the forward pass read.
//
//   height, width    plane dimensions.
//   y, x             sample coordinates in pixel units (row, column).
//   w1, w2, w3, w4   [out] weights of (y_low, x_low), (y_low, x_high),
//                    (y_high, x_low) and (y_high, x_high) respectively.
//   x_low .. y_high  [out] integer coordinates of the four neighbours.
//
// When the point lies more than one pixel before the first row/column or
// beyond `height`/`width`, all weights are set to 0 and all indices to -1 so
// the caller can skip the sample.
template <class T>
__device__ void bilinear_interpolate_gradient(const int height,
                                              const int width, T y, T x,
                                              T& w1, T& w2, T& w3, T& w4,
                                              int& x_low, int& x_high,
                                              int& y_low, int& y_high) {
  if (y < static_cast<T>(-1) || y > height || x < static_cast<T>(-1) ||
      x > width) {
    w1 = w2 = w3 = w4 = static_cast<T>(0);
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  // Points in the one-pixel band before the first row/column snap to it.
  if (y <= 0) {
    y = 0;
  }
  if (x <= 0) {
    x = 0;
  }
  y_low = static_cast<int>(y);
  x_low = static_cast<int>(x);
  // On (or past) the last row/column both neighbours collapse onto it.
  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = static_cast<T>(y_low);
  } else {
    y_high = y_low + 1;
  }
  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = static_cast<T>(x_low);
  } else {
    x_high = x_low + 1;
  }
  // The constant is typed as T so float instantiations stay in single
  // precision.
  const T one = static_cast<T>(1);
  const T ly = y - y_low, lx = x - x_low;
  const T hy = one - ly, hx = one - lx;
  w1 = hy * hx;
  w2 = hy * lx;
  w3 = ly * hx;
  w4 = ly * lx;
}

// Forward ROI-align kernel: each loop iteration produces one output element.
//
// The flat output index i is decomposed as
//   i = ((n * channels + c) * pooled_height + ph) * pooled_width + pw
// where n is the ROI index, c the channel and (ph, pw) the output bin.
//
//   nthreads           number of output elements to compute.
//   input_data         feature maps, read as
//                      [(batch * channels + c) * height * width + ...].
//   input_rois         ROI coordinates, kROISize values per ROI; the first
//                      four are read as xmin, ymin, xmax, ymax.
//   spatial_scale      factor applied to the ROI coordinates before sampling.
//   channels, height, width       dimensions of one input image.
//   pooled_height, pooled_width   output bin grid per ROI and channel.
//   sampling_ratio     samples per bin along each axis; when <= 0 it is
//                      derived as ceil(roi_size / pooled_size).
//   roi_batch_id_data  for every ROI, the index of the image it belongs to.
//   output_data        [out] averaged samples, one value per index i.
//
// Launch layout: 1-D grid of 1-D blocks; any grid size is valid because the
// body runs inside CUDA_1D_KERNEL_LOOP. No shared memory is used.
template <class T>
__global__ void GPUROIAlignForward(
    const int nthreads, const T* input_data, const T* input_rois,
    const float spatial_scale, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    const int sampling_ratio, int* roi_batch_id_data, T* output_data) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    const int pw = i % pooled_width;
    const int ph = (i / pooled_width) % pooled_height;
    const int c = (i / pooled_width / pooled_height) % channels;
    const int n = i / pooled_width / pooled_height / channels;

    const T* offset_input_rois = input_rois + n * kROISize;
    const int roi_batch_ind = roi_batch_id_data[n];

    const T roi_xmin = offset_input_rois[0] * spatial_scale;
    const T roi_ymin = offset_input_rois[1] * spatial_scale;
    const T roi_xmax = offset_input_rois[2] * spatial_scale;
    const T roi_ymax = offset_input_rois[3] * spatial_scale;

    // Force ROIs to be at least 1x1. The clamp is written as a plain
    // comparison so no host-only library function is called in device code.
    const T min_extent = static_cast<T>(1.);
    const T raw_width = roi_xmax - roi_xmin;
    const T raw_height = roi_ymax - roi_ymin;
    const T roi_width = (raw_width < min_extent) ? min_extent : raw_width;
    const T roi_height = (raw_height < min_extent) ? min_extent : raw_height;
    const T bin_size_h = roi_height / static_cast<T>(pooled_height);
    const T bin_size_w = roi_width / static_cast<T>(pooled_width);

    const T* offset_input_data =
        input_data + (roi_batch_ind * channels + c) * height * width;

    // Number of sample points per bin along each axis.
    const int roi_bin_grid_h = (sampling_ratio > 0)
                                   ? sampling_ratio
                                   : ceil(roi_height / pooled_height);
    const int roi_bin_grid_w = (sampling_ratio > 0)
                                   ? sampling_ratio
                                   : ceil(roi_width / pooled_width);
    const T count = roi_bin_grid_h * roi_bin_grid_w;

    T output_val = 0;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      // Sample rows sit at the centres of roi_bin_grid_h equal sub-cells.
      const T y = roi_ymin + ph * bin_size_h +
                  static_cast<T>(iy + .5f) * bin_size_h /
                      static_cast<T>(roi_bin_grid_h);
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_xmin + pw * bin_size_w +
                    static_cast<T>(ix + .5f) * bin_size_w /
                        static_cast<T>(roi_bin_grid_w);
        output_val +=
            bilinear_interpolate(offset_input_data, height, width, y, x);
      }
    }
    // Average over all sample points of the bin.
    output_val /= count;
    output_data[i] = output_val;
  }
}

// Backward ROI-align kernel: each loop iteration takes the gradient of one
// output element and scatters it to the input pixels that the forward pass
// sampled for that element.
//
// The flat index i is decomposed as
//   i = ((n * channels + c) * pooled_height + ph) * pooled_width + pw
// where n is the ROI index, c the channel and (ph, pw) the output bin.
//
//   nthreads           number of output-gradient elements to process.
//   input_rois         ROI coordinates, kROISize values per ROI; the first
//                      four are read as xmin, ymin, xmax, ymax.
//   output_grad        gradient w.r.t. the forward output, indexed like i.
//   num_rois           number of ROIs (not read by the kernel body).
//   spatial_scale      factor applied to the ROI coordinates before sampling.
//   channels, height, width       dimensions of one input image.
//   pooled_height, pooled_width   output bin grid per ROI and channel.
//   sampling_ratio     samples per bin along each axis; when <= 0 it is
//                      derived as ceil(roi_size / pooled_size).
//   roi_batch_id_data  for every ROI, the index of the image it belongs to.
//   input_grad         [in/out] gradient w.r.t. the input. Values are
//                      accumulated with atomic adds, so the caller must
//                      zero the buffer before the launch.
//
// Launch layout: 1-D grid of 1-D blocks; any grid size is valid because the
// body runs inside CUDA_1D_KERNEL_LOOP. No shared memory is used.
template <typename T>
__global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois,
                                    const T* output_grad, const int num_rois,
                                    const float spatial_scale,
                                    const int channels, const int height,
                                    const int width, const int pooled_height,
                                    const int pooled_width,
                                    const int sampling_ratio,
                                    int* roi_batch_id_data, T* input_grad) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    const int pw = i % pooled_width;
    const int ph = (i / pooled_width) % pooled_height;
    const int c = (i / pooled_width / pooled_height) % channels;
    const int n = i / pooled_width / pooled_height / channels;
    const T* offset_input_rois = input_rois + n * kROISize;
    const int roi_batch_ind = roi_batch_id_data[n];

    const T roi_xmin = offset_input_rois[0] * spatial_scale;
    const T roi_ymin = offset_input_rois[1] * spatial_scale;
    const T roi_xmax = offset_input_rois[2] * spatial_scale;
    const T roi_ymax = offset_input_rois[3] * spatial_scale;

    // Force ROIs to be at least 1x1. The clamp is written as a plain
    // comparison so no host-only library function is called in device code.
    const T min_extent = static_cast<T>(1.);
    const T raw_width = roi_xmax - roi_xmin;
    const T raw_height = roi_ymax - roi_ymin;
    const T roi_width = (raw_width < min_extent) ? min_extent : raw_width;
    const T roi_height = (raw_height < min_extent) ? min_extent : raw_height;
    const T bin_size_h = roi_height / static_cast<T>(pooled_height);
    const T bin_size_w = roi_width / static_cast<T>(pooled_width);

    // Writable: this plane is the destination of the atomic adds below.
    T* offset_input_grad =
        input_grad + (roi_batch_ind * channels + c) * height * width;

    const T* offset_out_grad =
        output_grad + (n * channels + c) * pooled_height * pooled_width;
    const T out_grad_this_bin = offset_out_grad[ph * pooled_width + pw];

    // Number of sample points per bin along each axis.
    const int roi_bin_grid_h = (sampling_ratio > 0)
                                   ? sampling_ratio
                                   : ceil(roi_height / pooled_height);
    const int roi_bin_grid_w = (sampling_ratio > 0)
                                   ? sampling_ratio
                                   : ceil(roi_width / pooled_width);

    const T count = roi_bin_grid_h * roi_bin_grid_w;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      // Same sample positions as in the forward kernel.
      const T y = roi_ymin + ph * bin_size_h +
                  static_cast<T>(iy + .5f) * bin_size_h /
                      static_cast<T>(roi_bin_grid_h);
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_xmin + pw * bin_size_w +
                    static_cast<T>(ix + .5f) * bin_size_w /
                        static_cast<T>(roi_bin_grid_w);
        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;
        bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                      x_low, x_high, y_low, y_high);
        // Each neighbour receives its bilinear share of the bin gradient,
        // divided by the number of samples averaged in the forward pass.
        const T diff1 = out_grad_this_bin * w1 / count;
        const T diff2 = out_grad_this_bin * w2 / count;
        const T diff3 = out_grad_this_bin * w3 / count;
        const T diff4 = out_grad_this_bin * w4 / count;
        // Indices are -1 when the sample fell outside the feature map.
        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // Atomic because different bins/ROIs may touch the same pixel.
          platform::CudaAtomicAdd(offset_input_grad + y_low * width + x_low,
                                  diff1);
          platform::CudaAtomicAdd(offset_input_grad + y_low * width + x_high,
                                  diff2);
          platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_low,
                                  diff3);
          platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_high,
                                  diff4);
        }
      }
    }
  }
}

// CUDA kernel class for the `roi_align` operator.
//
// Compute() reads the inputs "X" and "ROIs" and the attributes
// pooled_height, pooled_width, spatial_scale and sampling_ratio, builds a
// per-ROI batch-index table from the last LoD level of "ROIs", copies it to
// the execution place and launches GPUROIAlignForward to fill "Out".
template <typename Place, typename T>
class GPUROIAlignOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");
    auto* out = ctx.Output<Tensor>("Out");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");
    auto sampling_ratio = ctx.Attr<int>("sampling_ratio");

    // "X" is indexed as (batch, channels, height, width).
    auto in_dims = in->dims();
    int batch_size = in_dims[0];
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];

    int rois_num = rois->dims()[0];

    // Nothing to pool when there are no ROIs.
    if (rois_num == 0) return;

    int output_size = out->numel();
    int blocks = NumBlocks(output_size);
    int threads = kNumCUDAThreads;

    // Host-side table: roi_batch_id_data[r] is the image index of ROI r.
    Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({rois_num});
    int* roi_batch_id_data =
        roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
    auto rois_lod = rois->lod().back();
    int rois_batch_size = rois_lod.size() - 1;
    PADDLE_ENFORCE_EQ(
        rois_batch_size, batch_size,
        "The rois_batch_size and imgs batch_size must be the same.");
    int rois_num_with_lod = rois_lod[rois_batch_size];
    PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
                      "The rois_num from input and lod must be the same.");
    // ROIs in [rois_lod[n], rois_lod[n + 1]) belong to image n.
    for (int n = 0; n < rois_batch_size; ++n) {
      for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
        roi_batch_id_data[i] = n;
      }
    }
    // The kernel dereferences the table, so it is copied to the execution
    // place first.
    Tensor roi_batch_id_list_gpu;
    framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
                          ctx.device_context(), &roi_batch_id_list_gpu);
    GPUROIAlignForward<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
        height, width, pooled_height, pooled_width, sampling_ratio,
        roi_batch_id_list_gpu.data<int>(),
        out->mutable_data<T>(ctx.GetPlace()));
  }
};

// CUDA kernel class for the `roi_align_grad` operator.
//
// Compute() reads "X", "ROIs" and the gradient of "Out". When the gradient
// of "X" is requested, it zero-fills that tensor and launches
// GPUROIAlignBackward to accumulate the input gradients into it.
template <typename Place, typename T>
class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<LoDTensor>("ROIs");

    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");
    auto sampling_ratio = ctx.Attr<int>("sampling_ratio");

    int rois_num = rois->dims()[0];
    int channels = in->dims()[1];
    int height = in->dims()[2];
    int width = in->dims()[3];

    // Skip all work when the input gradient is not requested.
    if (in_grad) {
      // Host-side table: roi_batch_id_data[r] is the image index of ROI r.
      Tensor roi_batch_id_list;
      roi_batch_id_list.Resize({rois_num});
      int* roi_batch_id_data =
          roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
      auto rois_lod = rois->lod().back();
      int rois_batch_size = rois_lod.size() - 1;
      // ROIs in [rois_lod[n], rois_lod[n + 1]) belong to image n.
      for (int n = 0; n < rois_batch_size; ++n) {
        for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
          roi_batch_id_data[i] = n;
        }
      }
      // The kernel dereferences the table, so it is copied to the execution
      // place first.
      Tensor roi_batch_id_list_gpu;
      framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
                            ctx.device_context(), &roi_batch_id_list_gpu);

      // The backward kernel accumulates with atomic adds, so the gradient
      // buffer must start from zero.
      T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
      math::SetConstant<Place, T> set_zero;
      set_zero(ctx.cuda_device_context(), in_grad, static_cast<T>(0));

      int output_grad_size = out_grad->numel();
      int blocks = NumBlocks(output_grad_size);
      int threads = kNumCUDAThreads;

      // An empty output gradient would give a zero-block launch; skip it.
      if (output_grad_size > 0) {
        GPUROIAlignBackward<
            T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
            output_grad_size, rois->data<T>(), out_grad->data<T>(), rois_num,
            spatial_scale, channels, height, width, pooled_height, pooled_width,
            sampling_ratio, roi_batch_id_list_gpu.data<int>(), in_grad_data);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

// Register the CUDA implementations of the forward (`roi_align`) and backward
// (`roi_align_grad`) operators, each instantiated for float and double.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    roi_align,
    ops::GPUROIAlignOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIAlignOpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    roi_align_grad,
    ops::GPUROIAlignGradOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUROIAlignGradOpKernel<paddle::platform::CUDADeviceContext, double>);