interpolate_op.cu 11.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include <string>
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

// Nearest-neighbour resize, forward pass.
//
// The input is addressed as `input_h` rows of `input_w` elements. Each row is
// split into `num_channels` contiguous planes of `in_img_h * in_img_w` pixels.
// The output is addressed the same way (`output_h` x `output_w`,
// `out_img_h` x `out_img_w` per plane).
//
// One loop iteration writes one output element. The loop is grid-stride
// (advances by blockDim.x * gridDim.x), so any 1-D launch shape covers all
// `output_h * output_w` elements.
//
// The source coordinate along each axis is `static_cast<int>(ratio * dst + 0.5)`,
// i.e. `ratio * dst` rounded half-up.
template <typename T>
__global__ void KeNearestNeighborInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w) {
  const int total = output_h * output_w;
  const int step = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total; i += step) {
    // Flat index -> (row, column inside the row).
    const int row = i / output_w;
    const int col = i % output_w;

    // Elements per channel plane on each side, and the channel of this column.
    const int src_plane = input_w / num_channels;
    const int dst_plane = output_w / num_channels;
    const int ch = col / dst_plane;

    // Pixel coordinates of the output element inside its plane.
    const int dst_y = (col % dst_plane) / out_img_w;
    const int dst_x = i % out_img_w;

    // Nearest source pixel.
    const int src_y = static_cast<int>(ratio_h * dst_y + 0.5);
    const int src_x = static_cast<int>(ratio_w * dst_x + 0.5);

    const T* src = in + row * input_w + ch * src_plane;
    out[i] = src[src_y * in_img_w + src_x];
  }
}

// Nearest-neighbour resize, backward pass.
//
// `out` holds the output gradient and `in` the input gradient. Each output
// gradient element is added into the source pixel that the forward
// nearest-neighbour kernel would have read. Layout, index mapping and rounding
// (`static_cast<int>(ratio * dst + 0.5)`) match that forward kernel.
//
// Several output pixels can map to the same source pixel, so the scatter goes
// through platform::CudaAtomicAdd. The values are accumulated into whatever
// `in` already contains. The loop is grid-stride, so any 1-D launch shape
// covers all `output_h * output_w` elements.
template <typename T>
__global__ void KeNearestNeighborInterpBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
    const size_t input_w, const T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w) {
  const int total = output_h * output_w;
  const int step = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total; i += step) {
    // Flat index -> (row, column inside the row).
    const int row = i / output_w;
    const int col = i % output_w;

    // Elements per channel plane on each side, and the channel of this column.
    const int src_plane = input_w / num_channels;
    const int dst_plane = output_w / num_channels;
    const int ch = col / dst_plane;

    // Pixel coordinates of the output element inside its plane.
    const int dst_y = (col % dst_plane) / out_img_w;
    const int dst_x = i % out_img_w;

    // Nearest source pixel (same rounding as the forward pass).
    const int src_y = static_cast<int>(ratio_h * dst_y + 0.5);
    const int src_x = static_cast<int>(ratio_w * dst_x + 0.5);

    T* target =
        in + row * input_w + ch * src_plane + src_y * in_img_w + src_x;
    const T grad = out[row * output_w + col];
    platform::CudaAtomicAdd(target, grad);
  }
}

// Bilinear resize, forward pass.
//
// Layout: `input_h` rows of `input_w` elements, each row made of
// `num_channels` contiguous `in_img_h` x `in_img_w` planes. The output uses the
// same layout with the `out_*` / `output_*` sizes.
//
// For output pixel (dst_y, dst_x):
//   y0 = int(ratio_h * dst_y), with the fractional part as the vertical weight;
//   x0 = int(ratio_w * dst_x), with the fractional part as the horizontal weight.
// The four neighbours are blended. On the last source row/column the "next"
// offset is 0, so the edge pixel is reused instead of reading past the plane.
// The loop is grid-stride, so any 1-D launch shape covers all
// `output_h * output_w` elements.
template <typename T>
__global__ void KeBilinearInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w) {
  const int total = output_h * output_w;
  const int step = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total; i += step) {
    // Flat index -> (row, column inside the row).
    const int row = i / output_w;
    const int col = i % output_w;

    // Elements per channel plane on each side, and the channel of this column.
    const int src_plane = input_w / num_channels;
    const int dst_plane = output_w / num_channels;
    const int ch = col / dst_plane;

    // Vertical: top source row, offset to the row below (0 on the last row),
    // and blend weights.
    const int dst_y = (col % dst_plane) / out_img_w;
    const int y0 = ratio_h * dst_y;
    const int dy = (y0 < in_img_h - 1) ? 1 : 0;
    const T wy1 = ratio_h * dst_y - y0;
    const T wy0 = 1.f - wy1;

    // Horizontal: left source column, offset to the next column (0 on the
    // last column), and blend weights.
    const int dst_x = i % out_img_w;
    const int x0 = ratio_w * dst_x;
    const int dx = (x0 < in_img_w - 1) ? 1 : 0;
    const T wx1 = ratio_w * dst_x - x0;
    const T wx0 = 1.f - wx1;

    const T* p = in + row * input_w + ch * src_plane + y0 * in_img_w + x0;
    const size_t below = dy * in_img_w;

    out[row * output_w + col] =
        wy0 * (wx0 * p[0] + wx1 * p[dx]) +
        wy1 * (wx0 * p[below] + wx1 * p[below + dx]);
  }
}

// Bilinear resize, backward pass.
//
// `out` holds the output gradient and `in` the input gradient. Each output
// gradient element is scattered onto the (up to) four source pixels that the
// forward bilinear kernel blended for that output pixel, weighted by the same
// coefficients. Layout and index mapping match the forward kernel.
//
// Neighbouring output pixels share source pixels, so every update goes
// through platform::CudaAtomicAdd. Values are accumulated into whatever `in`
// already contains. The loop is grid-stride, so any 1-D launch shape covers
// all `output_h * output_w` elements.
//
// The ratios are `float`, as in every other kernel in this file and as the
// host passes them. Taking them as `T` made the T=double backward pass
// truncate `ratio * idx` in double while the forward pass truncated in float.
// That could select a different source row/column and weights than the
// forward pass actually used.
template <typename T>
__global__ void KeBilinearInterpBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
    const size_t input_w, const T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (; tid < nthreads; tid += stride) {
    // Flat index -> (row, column inside the row) and channel plane sizes.
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    // Vertical source row, offset to the row below (0 on the last row) and
    // weights -- computed exactly as in the forward kernel.
    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = ratio_h * out_img_idy;
    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
    T h1lambda = ratio_h * out_img_idy - in_img_idy;
    T h2lambda = 1.f - h1lambda;

    // Horizontal source column, offset to the next column and weights.
    int out_img_idx = tid % out_img_w;
    int in_img_idx = ratio_w * out_img_idx;
    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
    T w1lambda = ratio_w * out_img_idx - in_img_idx;
    T w2lambda = 1.f - w1lambda;

    T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                    in_img_idy * in_img_w + in_img_idx];
    // Read the incoming gradient once; it feeds all four scatters.
    const T grad = out[out_id_h * output_w + out_id_w];
    platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * grad);
    platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * grad);
    platform::CudaAtomicAdd(&in_pos[h_id * in_img_w],
                            h1lambda * w2lambda * grad);
    platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id],
                            h1lambda * w1lambda * grad);
  }
}

// Forward CUDA kernel of the `interpolate` operator.
//
// Resizes input X, read here as a 4-D tensor [n, c, in_h, in_w], to
// [n, c, out_h, out_w] with the method named by the `interp_method` attribute.
// Only "nearest" and "bilinear" launch a kernel; any other value launches
// nothing. The target size comes from the `out_h`/`out_w` attributes unless
// the optional `OutSize` input is present. In that case its first two int
// entries override them.
template <typename T>
class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* input_data = input->data<T>();

    auto interp_method = ctx.Attr<std::string>("interp_method");
    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      // The sizes are dereferenced on the host right below, so use the
      // synchronous copy: the data must have landed before it is read.
      Tensor sizes;
      framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
      auto size_data = sizes.data<int>();
      out_h = size_data[0];
      out_w = size_data[1];
    }

    int n = input->dims()[0];
    int c = input->dims()[1];
    int in_h = input->dims()[2];
    int in_w = input->dims()[3];

    auto* output_data =
        output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());

    // Same spatial size: the resize is the identity, so just copy.
    if (in_h == out_h && in_w == out_w) {
      framework::TensorCopy(*input, ctx.GetPlace(), output);
      return;
    }

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = c * in_hw;
    int out_chw = c * out_hw;

    int pixelNum = n * out_chw;
    // Nothing to compute for an empty output, and a zero-block grid would be
    // an invalid launch configuration.
    if (pixelNum == 0) return;

    // Maps output index 0 to input index 0 and output index (out - 1) to
    // input index (in - 1); a single output pixel samples index 0.
    float ratio_h =
        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
    float ratio_w =
        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;

    // The kernels loop with stride gridDim.x * blockDim.x, so capping the
    // grid at 8 blocks still covers every output element.
    int grid_dim = (pixelNum + 512 - 1) / 512;
    grid_dim = grid_dim > 8 ? 8 : grid_dim;
    auto stream = ctx.cuda_device_context().stream();

    // The kernels view each sample as one row: n rows of c*h*w elements.
    if ("nearest" == interp_method) {
      KeNearestNeighborInterpFw<T><<<grid_dim, 512, 0, stream>>>(
          input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
          out_chw, c, ratio_h, ratio_w);
    } else if ("bilinear" == interp_method) {
      KeBilinearInterpFw<T><<<grid_dim, 512, 0, stream>>>(
          input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
          out_chw, c, ratio_h, ratio_w);
    }
  }
};

// Backward CUDA kernel of the `interpolate` operator.
//
// Computes X@GRAD, read here as [n, c, in_h, in_w], from Out@GRAD sized by
// out_h/out_w. The target size is taken from the `out_h`/`out_w` attributes,
// or from the first two int entries of the optional `OutSize` input. The input
// gradient is zero-filled first because the backward kernels only accumulate
// into it.
template <typename T>
class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* output_grad_data = output_grad->data<T>();
    auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());

    // The backward kernels add into input_grad, so start from zero.
    auto& device_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    math::SetConstant<platform::CUDADeviceContext, T> zero;
    zero(device_ctx, input_grad, static_cast<T>(0.0));

    auto interp_method = ctx.Attr<std::string>("interp_method");
    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      // The sizes are dereferenced on the host right below, so use the
      // synchronous copy: the data must have landed before it is read.
      Tensor sizes;
      framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
      auto size_data = sizes.data<int>();
      out_h = size_data[0];
      out_w = size_data[1];
    }

    int n = input_grad->dims()[0];
    int c = input_grad->dims()[1];
    int in_h = input_grad->dims()[2];
    int in_w = input_grad->dims()[3];

    // Same spatial size: the forward op was the identity, so the gradient
    // passes straight through.
    if (in_h == out_h && in_w == out_w) {
      framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
      return;
    }

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = c * in_hw;
    int out_chw = c * out_hw;

    int pixelNum = n * out_chw;
    // Nothing to scatter for an empty output gradient, and a zero-block grid
    // would be an invalid launch configuration.
    if (pixelNum == 0) return;

    // Same mapping as the forward pass: output index 0 -> input index 0,
    // output index (out - 1) -> input index (in - 1).
    float ratio_h =
        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
    float ratio_w =
        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;

    // The kernels loop with stride gridDim.x * blockDim.x, so capping the
    // grid at 8 blocks still covers every output element.
    int grid_dim = (pixelNum + 512 - 1) / 512;
    grid_dim = grid_dim > 8 ? 8 : grid_dim;
    auto stream = ctx.cuda_device_context().stream();

    // The kernels view each sample as one row: n rows of c*h*w elements.
    if ("nearest" == interp_method) {
      KeNearestNeighborInterpBw<T><<<grid_dim, 512, 0, stream>>>(
          input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
          out_w, n, out_chw, c, ratio_h, ratio_w);
    } else if ("bilinear" == interp_method) {
      KeBilinearInterpBw<T><<<grid_dim, 512, 0, stream>>>(
          input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
          out_w, n, out_chw, c, ratio_h, ratio_w);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward `interpolate` kernels for float, double and int elements.
REGISTER_OP_CUDA_KERNEL(
    interpolate,
    ops::InterpolateOpCUDAKernel<float>,
    ops::InterpolateOpCUDAKernel<double>,
    ops::InterpolateOpCUDAKernel<int>);

// Backward `interpolate_grad` kernels for float and double elements.
REGISTER_OP_CUDA_KERNEL(
    interpolate_grad,
    ops::InterpolateGradOpCUDAKernel<float>,
    ops::InterpolateGradOpCUDAKernel<double>);