interpolate_op.cu 14.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include <string>
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

// Nearest-neighbor resize, forward pass.
//
// Indexing treats `in` as input_h rows of input_w elements, each row holding
// num_channels consecutive in_img_h x in_img_w planes; `out` is viewed the same
// way with out_img_h x out_img_w planes. The host wrappers in this file pass
// input_h = output_h = batch size and input_w / output_w = channels * plane
// size. Each loop iteration produces one output element; the loop advances by
// the total number of launched threads, so any grid size covers all
// output_h * output_w elements.
//
// ratio_h / ratio_w map an output coordinate back to the input. With
// align_corners the scaled coordinate is rounded (x + 0.5, then truncated);
// otherwise it is truncated.
template <typename T>
__global__ void KeNearestNeighborInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w,
    const bool align_corners) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;

  // Per-channel plane sizes do not depend on the element being processed, so
  // compute these (integer) divisions once per thread instead of per element.
  int in_img_size = input_w / num_channels;
  int out_img_size = output_w / num_channels;

  for (; tid < nthreads; tid += stride) {
    int out_id_h = tid / output_w;  // row of the flattened output
    int out_id_w = tid % output_w;  // position inside that row
    int channel_id = out_id_w / out_img_size;

    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = (align_corners)
                         ? static_cast<int>(ratio_h * out_img_idy + 0.5)
                         : static_cast<int>(ratio_h * out_img_idy);

    // output_w is a multiple of out_img_w, so this equals the column inside
    // the current plane.
    int out_img_idx = tid % out_img_w;
    int in_img_idx = (align_corners)
                         ? static_cast<int>(ratio_w * out_img_idx + 0.5)
                         : static_cast<int>(ratio_w * out_img_idx);

    out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
                  in_img_idy * in_img_w + in_img_idx];
  }
}

// Nearest-neighbor resize, backward pass.
//
// Uses the same element layout and index mapping as KeNearestNeighborInterpFw,
// but `in` is the input gradient being accumulated and `out` is the output
// gradient. Several output elements can select the same input element, so
// each contribution is added with an atomic add; `in` must be zero-filled by
// the caller before launch.
template <typename T>
__global__ void KeNearestNeighborInterpBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
    const size_t input_w, const T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w,
    const bool align_corners) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;

  // Loop-invariant per-channel plane sizes.
  int in_img_size = input_w / num_channels;
  int out_img_size = output_w / num_channels;

  for (; tid < nthreads; tid += stride) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int channel_id = out_id_w / out_img_size;

    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = (align_corners)
                         ? static_cast<int>(ratio_h * out_img_idy + 0.5)
                         : static_cast<int>(ratio_h * out_img_idy);

    int out_img_idx = tid % out_img_w;
    int in_img_idx = (align_corners)
                         ? static_cast<int>(ratio_w * out_img_idx + 0.5)
                         : static_cast<int>(ratio_w * out_img_idx);

    T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                    in_img_idy * in_img_w + in_img_idx];
    const T out_pos = out[out_id_h * output_w + out_id_w];
    platform::CudaAtomicAdd(in_pos, out_pos);
  }
}

// Maps one output coordinate along a single axis (rows or columns) to the
// bilinear source taps used by KeBilinearInterpFw and KeBilinearInterpBw.
//
//   out_idx    - output coordinate along the axis.
//   ratio      - input/output scale factor computed by the host wrapper.
//   in_size    - input extent along the axis.
//   align_flag - true selects the half-pixel mapping
//                src = ratio * (out_idx + 0.5) - 0.5; false selects
//                src = ratio * out_idx.
//   in_idx     - receives the first tap index (truncated src, clamped >= 0).
//   next       - receives 1 when in_idx + 1 is still inside the axis,
//                otherwise 0 so both taps read in_idx.
//   lambda     - receives the weight of the second tap (src - in_idx).
//
// With the half-pixel mapping src is negative for the first output positions
// whenever ratio < 1 (upsampling). The index was already clamped to 0; src is
// clamped too so the weight cannot go negative. Without it the border pixels
// were extrapolated (weights outside [0, 1]) rather than interpolated.
// NOTE(review): if the CPU kernels (interpolate_op.h, not visible here) still
// form the weight from the unclamped src, they need the same clamp so CPU and
// GPU results agree.
template <typename T, typename R>
__device__ __forceinline__ void BilinearSourceIndex(
    const int out_idx, const R ratio, const size_t in_size,
    const bool align_flag, int* in_idx, int* next, T* lambda) {
  int idx;
  if (align_flag) {
    double src = ratio * (out_idx + 0.5) - 0.5;
    idx = static_cast<int>(src);
    idx = (idx > 0) ? idx : 0;
    src = (src > 0.0) ? src : 0.0;
    *lambda = src - idx;
  } else {
    R src = ratio * out_idx;
    idx = static_cast<int>(src);
    idx = (idx > 0) ? idx : 0;
    *lambda = src - idx;
  }
  *in_idx = idx;
  *next = (static_cast<size_t>(idx) < in_size - 1) ? 1 : 0;
}

// Bilinear resize, forward pass.
//
// Same element layout and loop structure as KeNearestNeighborInterpFw. Each
// output element blends the input taps
//   (y, x), (y, x + w_id), (y + h_id, x), (y + h_id, x + w_id)
// with row weights h2lambda / h1lambda and column weights w2lambda / w1lambda.
// align_mode == 0 without align_corners selects the half-pixel source mapping.
template <typename T>
__global__ void KeBilinearInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w,
    const bool align_corners, const int align_mode) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  bool align_flag = (align_mode == 0 && !align_corners);

  // Loop-invariant per-channel plane sizes.
  int in_img_size = input_w / num_channels;
  int out_img_size = output_w / num_channels;

  for (; tid < nthreads; tid += stride) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int channel_id = out_id_w / out_img_size;

    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy, h_id;
    T h1lambda;
    BilinearSourceIndex(out_img_idy, ratio_h, in_img_h, align_flag,
                        &in_img_idy, &h_id, &h1lambda);
    T h2lambda = 1.f - h1lambda;

    int out_img_idx = tid % out_img_w;
    int in_img_idx, w_id;
    T w1lambda;
    BilinearSourceIndex(out_img_idx, ratio_w, in_img_w, align_flag,
                        &in_img_idx, &w_id, &w1lambda);
    T w2lambda = 1.f - w1lambda;

    const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                          in_img_idy * in_img_w + in_img_idx];

    out[out_id_h * output_w + out_id_w] =
        h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
        h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
                    w1lambda * in_pos[h_id * in_img_w + w_id]);
  }
}

// Bilinear resize, backward pass.
//
// Scatters each output-gradient element to the same four input taps, with the
// same weights, that KeBilinearInterpFw reads from. Neighbouring output
// elements share taps, so the scatter uses atomic adds into `in`, which must
// be zero-filled by the caller before launch.
template <typename T>
__global__ void KeBilinearInterpBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
    const size_t input_w, const T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const T ratio_h, const T ratio_w,
    const bool align_corners, const int align_mode) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  bool align_flag = (align_mode == 0 && !align_corners);

  // Loop-invariant per-channel plane sizes.
  int in_img_size = input_w / num_channels;
  int out_img_size = output_w / num_channels;

  for (; tid < nthreads; tid += stride) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int channel_id = out_id_w / out_img_size;

    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy, h_id;
    T h1lambda;
    BilinearSourceIndex(out_img_idy, ratio_h, in_img_h, align_flag,
                        &in_img_idy, &h_id, &h1lambda);
    T h2lambda = 1.f - h1lambda;

    int out_img_idx = tid % out_img_w;
    int in_img_idx, w_id;
    T w1lambda;
    BilinearSourceIndex(out_img_idx, ratio_w, in_img_w, align_flag,
                        &in_img_idx, &w_id, &w1lambda);
    T w2lambda = 1.f - w1lambda;

    T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                    in_img_idy * in_img_w + in_img_idx];
    const T grad = out[out_id_h * output_w + out_id_w];
    platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * grad);
    platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * grad);
    platform::CudaAtomicAdd(&in_pos[h_id * in_img_w],
                            h1lambda * w2lambda * grad);
    platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id],
                            h1lambda * w1lambda * grad);
  }
}

// GPU forward kernel shared by the bilinear_interp and nearest_interp ops.
template <typename T>
class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
 public:
  // Resizes X (dims read as n, c, in_h, in_w) into Out (n, c, out_h, out_w)
  // using the method named by the "interp_method" attribute ("nearest" or
  // "bilinear"; any other value launches nothing).
  //
  // Output size, later sources overriding earlier ones: the out_h/out_w
  // attributes, then in_h/in_w times "scale" when scale > 0, then the
  // optional OutSize input (read as two ints: height, width).
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* input_data = input->data<T>();

    int n = input->dims()[0];
    int c = input->dims()[1];
    int in_h = input->dims()[2];
    int in_w = input->dims()[3];

    auto interp_method = ctx.Attr<std::string>("interp_method");
    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");

    float scale = ctx.Attr<float>("scale");
    if (scale > 0) {
      out_h = in_h * scale;
      out_w = in_w * scale;
    }

    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      // size_data is read on the host right after the copy, so the copy must
      // be complete on return. NOTE(review): framework::TensorCopy is
      // stream-ordered for a device source; TensorCopySync waits for it.
      Tensor sizes;
      framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
      auto size_data = sizes.data<int>();
      out_h = size_data[0];
      out_w = size_data[1];
    }

    bool align_corners = ctx.Attr<bool>("align_corners");
    int align_mode = ctx.Attr<int>("align_mode");

    auto* output_data =
        output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = c * in_hw;
    int out_chw = c * out_hw;

    // Factors mapping output coordinates back to input coordinates. With
    // align_corners the first and last pixels of input and output coincide.
    // A single output row/column keeps ratio 0 (always samples index 0).
    float ratio_h = 0.f;
    float ratio_w = 0.f;
    if (out_h > 1) {
      ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                                : static_cast<float>(in_h) / out_h;
    }
    if (out_w > 1) {
      ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                                : static_cast<float>(in_w) / out_w;
    }

    // Same spatial size: the resize is an identity mapping, copy X directly.
    if (in_h == out_h && in_w == out_w) {
      framework::TensorCopy(*input, ctx.GetPlace(), output);
      return;
    }

    int pixelNum = n * out_chw;
    // Empty output: nothing to compute, and a launch with zero blocks would
    // be rejected as an invalid configuration.
    if (pixelNum <= 0) {
      return;
    }
    // 512 threads per block, at most 8 blocks. The kernels advance by the
    // total thread count, so every output element is still visited.
    int grid_dim = (pixelNum + 512 - 1) / 512;
    grid_dim = grid_dim > 8 ? 8 : grid_dim;

    if ("nearest" == interp_method) {
      KeNearestNeighborInterpFw<
          T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
          input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
          out_chw, c, ratio_h, ratio_w, align_corners);
    } else if ("bilinear" == interp_method) {
      KeBilinearInterpFw<
          T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
          input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
          out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
    }
  }
};

// GPU backward kernel shared by the bilinear_interp_grad and
// nearest_interp_grad ops.
template <typename T>
class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  // Computes dX (dims read as n, c, in_h, in_w) from dOut. The output size is
  // resolved exactly as in the forward kernel: out_h/out_w attributes, then
  // "scale" when > 0, then the optional OutSize input.
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* output_grad_data = output_grad->data<T>();
    auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());

    int n = input_grad->dims()[0];
    int c = input_grad->dims()[1];
    int in_h = input_grad->dims()[2];
    int in_w = input_grad->dims()[3];

    auto interp_method = ctx.Attr<std::string>("interp_method");
    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");

    float scale = ctx.Attr<float>("scale");
    if (scale > 0) {
      out_h = in_h * scale;
      out_w = in_w * scale;
    }

    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      // size_data is read on the host right after the copy, so the copy must
      // be complete on return. NOTE(review): framework::TensorCopy is
      // stream-ordered for a device source; TensorCopySync waits for it.
      Tensor sizes;
      framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
      auto size_data = sizes.data<int>();
      out_h = size_data[0];
      out_w = size_data[1];
    }

    bool align_corners = ctx.Attr<bool>("align_corners");
    int align_mode = ctx.Attr<int>("align_mode");

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = c * in_hw;
    int out_chw = c * out_hw;

    // Must match the forward kernel's ratios so the gradient uses the same
    // sampling positions and weights.
    float ratio_h = 0.f;
    float ratio_w = 0.f;
    if (out_h > 1) {
      ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                                : static_cast<float>(in_h) / out_h;
    }
    if (out_w > 1) {
      ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                                : static_cast<float>(in_w) / out_w;
    }

    // Same spatial size: the forward pass was an identity copy, so dX = dOut.
    // The copy overwrites all of dX, so no zero-fill is needed on this path.
    if (in_h == out_h && in_w == out_w) {
      framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
      return;
    }

    // The backward kernels accumulate into dX with atomic adds, so it has to
    // start from zero. This also leaves dX correct (all zero) when the output
    // is empty and no kernel is launched below.
    auto& device_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    math::SetConstant<platform::CUDADeviceContext, T> zero;
    zero(device_ctx, input_grad, static_cast<T>(0.0));

    int pixelNum = n * out_chw;
    // Empty dOut: nothing to scatter, and a launch with zero blocks would be
    // rejected as an invalid configuration.
    if (pixelNum <= 0) {
      return;
    }
    // 512 threads per block, at most 8 blocks; the kernels advance by the
    // total thread count, so every dOut element is still visited.
    int grid_dim = (pixelNum + 512 - 1) / 512;
    grid_dim = grid_dim > 8 ? 8 : grid_dim;

    if ("nearest" == interp_method) {
      KeNearestNeighborInterpBw<
          T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
          input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
          out_w, n, out_chw, c, ratio_h, ratio_w, align_corners);
    } else if ("bilinear" == interp_method) {
      KeBilinearInterpBw<
          T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
          input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
          out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Bilinear resize: forward for float/double/int, gradient for float/double.
REGISTER_OP_CUDA_KERNEL(
    bilinear_interp,
    ops::InterpolateOpCUDAKernel<float>,
    ops::InterpolateOpCUDAKernel<double>,
    ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(
    bilinear_interp_grad,
    ops::InterpolateGradOpCUDAKernel<float>,
    ops::InterpolateGradOpCUDAKernel<double>);

// Nearest-neighbor resize: same element types as the bilinear registrations.
REGISTER_OP_CUDA_KERNEL(
    nearest_interp,
    ops::InterpolateOpCUDAKernel<float>,
    ops::InterpolateOpCUDAKernel<double>,
    ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(
    nearest_interp_grad,
    ops::InterpolateGradOpCUDAKernel<float>,
    ops::InterpolateGradOpCUDAKernel<double>);