interpolate_op.h 56.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
X
xiaoting 已提交
13
#include <algorithm>
#include <cmath>
14
#include <string>
15
#include <vector>
16

17
#include "paddle/fluid/framework/op_registry.h"
18 19
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/math_function.h"
20 21 22 23

namespace paddle {
namespace operators {

24 25 26
template <typename T,
          size_t D,
          int MajorType = Eigen::RowMajor,
27
          typename IndexType = Eigen::DenseIndex>
28
using EigenTensor = phi::EigenTensor<T, D, MajorType, IndexType>;
29
using Tensor = phi::DenseTensor;
30
using DataLayout = phi::DataLayout;
31

32
inline std::vector<int> get_new_shape(
33
    const std::vector<const phi::DenseTensor*>& list_new_shape_tensor) {
34 35 36 37
  // get tensor from
  std::vector<int> vec_new_shape;
  for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
    auto tensor = list_new_shape_tensor[i];
38 39
    PADDLE_ENFORCE_EQ(tensor->dims(),
                      phi::make_ddim({1}),
K
Kqnonrime 已提交
40 41 42 43
                      platform::errors::InvalidArgument(
                          "The shape of dimension tensor should be [1],"
                          "but received d%.",
                          tensor->dims()));
44 45
    if (platform::is_gpu_place(tensor->place()) ||
        platform::is_mlu_place(tensor->place())) {
46
      phi::DenseTensor temp;
47
      paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
48 49 50 51 52 53 54 55 56 57
      vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
    } else {
      vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
    }
  }

  return vec_new_shape;
}

template <typename T>
58 59
inline std::vector<T> get_new_data_from_tensor(
    const phi::DenseTensor* new_data_tensor) {
60 61
  std::vector<T> vec_new_data;
  auto* new_data = new_data_tensor->data<T>();
62
  phi::DenseTensor cpu_starts_tensor;
63 64
  if (platform::is_gpu_place(new_data_tensor->place()) ||
      platform::is_mlu_place(new_data_tensor->place())) {
65 66
    paddle::framework::TensorCopySync(
        *new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
67 68 69 70 71 72
    new_data = cpu_starts_tensor.data<T>();
  }
  vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
  return vec_new_data;
}

73
/// Splits a 3-, 4- or 5-D input shape into batch/channel/depth/height/width.
///
/// kNCHW is treated as channel-first; any other layout as channel-last.
/// Axes that do not exist for the given rank are reported as 1. Any rank
/// other than 3 or 4 is decoded as 5-D.
inline void ExtractNCDWH(const framework::DDim& dims,
                         const DataLayout& data_layout,
                         int* N,
                         int* C,
                         int* D,
                         int* H,
                         int* W) {
  const bool channel_first = (data_layout == DataLayout::kNCHW);
  *N = dims[0];
  switch (dims.size()) {
    case 3:  // [N, C, W] or [N, W, C]
      *C = channel_first ? dims[1] : dims[2];
      *D = 1;
      *H = 1;
      *W = channel_first ? dims[2] : dims[1];
      break;
    case 4:  // [N, C, H, W] or [N, H, W, C]
      *C = channel_first ? dims[1] : dims[3];
      *D = 1;
      *H = channel_first ? dims[2] : dims[1];
      *W = channel_first ? dims[3] : dims[2];
      break;
    default:  // [N, C, D, H, W] or [N, D, H, W, C]
      *C = channel_first ? dims[1] : dims[4];
      *D = channel_first ? dims[2] : dims[1];
      *H = channel_first ? dims[3] : dims[2];
      *W = channel_first ? dims[4] : dims[3];
      break;
  }
}

100
template <typename T>
101 102
static void NearestNeighborInterpolate(const phi::DenseTensor& input,
                                       phi::DenseTensor* output,
103 104 105 106 107 108
                                       const float ratio_h,
                                       const float ratio_w,
                                       const int n,
                                       const int c,
                                       const int out_h,
                                       const int out_w,
109 110
                                       const bool align_corners,
                                       const DataLayout& data_layout) {
111 112 113
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);
  for (int k = 0; k < out_h; k++) {  // loop for images
114 115
    int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
                               : static_cast<int>(ratio_h * k);
116 117

    for (int l = 0; l < out_w; l++) {
118 119
      int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
                                 : static_cast<int>(ratio_w * l);
120 121 122

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
123 124 125 126 127
          if (data_layout == DataLayout::kNCHW) {
            output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
          } else {
            output_t(i, k, l, j) = input_t(i, in_k, in_l, j);
          }
128 129 130 131 132 133
        }
      }
    }
  }
}

134
template <typename T>
135 136
static void LinearInterpolation(const phi::DenseTensor& input,
                                phi::DenseTensor* output,
137 138 139 140 141 142 143
                                const float ratio_w,
                                const int in_w,
                                const int n,
                                const int c,
                                const int out_w,
                                const bool align_corners,
                                const bool align_mode,
144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
                                const DataLayout data_layout) {
  auto input_t = EigenTensor<T, 3>::From(input);
  auto output_t = EigenTensor<T, 3>::From(*output);
  bool align_flag = (align_mode == 0 && !align_corners);

  std::vector<int> vx_w, vx_e;
  std::vector<float> vd_w, vd_e;
  vx_w.reserve(out_w);
  vx_e.reserve(out_w);
  vd_w.reserve(out_w);
  vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int l = 0; l < out_w; l++) {
    int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;                       // w
    int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w;  // w_id

    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;  // w1lambda
    float d_e = 1.f - d_w;                                         // w2lambda
    {
      vx_w[l] = x_w;
      vx_e[l] = x_e;
      vd_w[l] = d_w;
      vd_e[l] = d_e;
    }
  }

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
  for (int i = 0; i < n; i++) {    // loop for batches
    for (int j = 0; j < c; j++) {  // loop for channels
      for (int l = 0; l < out_w; l++) {
        // linear interpolation
        T out_t;
        if (data_layout == DataLayout::kNCHW) {
          out_t = input_t(i, j, vx_w[l]) * vd_e[l] +
                  input_t(i, j, vx_e[l]) * vd_w[l];
          output_t(i, j, l) = out_t;
        } else {
          out_t = input_t(i, vx_w[l], j) * vd_e[l] +
                  input_t(i, vx_e[l], j) * vd_w[l];
          output_t(i, l, j) = out_t;
        }
      }
    }
  }
}

template <typename T>
199 200
static void LinearInterpolationGrad(const phi::DenseTensor& output_grad,
                                    phi::DenseTensor* input_grad,
201 202 203 204 205 206
                                    const float ratio_w,
                                    const int in_w,
                                    const int n,
                                    const int c,
                                    const int out_w,
                                    const bool align_corners,
207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
                                    const int align_mode,
                                    const DataLayout data_layout) {
  auto input_grad_t = EigenTensor<T, 3>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 3>::From(output_grad);
  bool align_flag = (align_mode == 0 && !align_corners);
  for (int l = 0; l < out_w; l++) {
    int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;                       // w
    int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w;  // w_id

    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;  // w1lambda
    float d_e = 1.f - d_w;                                         // w2lambda

    for (int i = 0; i < n; i++) {    // loop for batches
      for (int j = 0; j < c; j++) {  // loop for channels
        // linear interpolation grad
        if (data_layout == DataLayout::kNCHW) {
          const T grad = output_grad_t(i, j, l);
          input_grad_t(i, j, x_w) += static_cast<T>(grad * d_e);
          input_grad_t(i, j, x_e) += static_cast<T>(grad * d_w);
        } else {
          const T grad = output_grad_t(i, l, j);
          input_grad_t(i, x_w, j) += static_cast<T>(grad * d_e);
          input_grad_t(i, x_e, j) += static_cast<T>(grad * d_w);
        }
      }
    }
  }
}

240
template <typename T>
241 242
static void BilinearInterpolation(const phi::DenseTensor& input,
                                  phi::DenseTensor* output,
243 244 245 246 247 248 249 250
                                  const float ratio_h,
                                  const float ratio_w,
                                  const int in_h,
                                  const int in_w,
                                  const int n,
                                  const int c,
                                  const int out_h,
                                  const int out_w,
251
                                  const bool align_corners,
252 253
                                  const bool align_mode,
                                  const DataLayout data_layout) {
254 255
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);
T
tink2123 已提交
256
  bool align_flag = (align_mode == 0 && !align_corners);
257 258 259 260 261 262 263 264 265 266 267

  std::vector<int> vy_n, vy_s;
  std::vector<float> vd_n, vd_s;
  vy_n.reserve(out_h);
  vy_s.reserve(out_h);
  vd_n.reserve(out_h);
  vd_s.reserve(out_h);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int k = 0; k < out_h; k++) {
T
tink2123 已提交
268 269
    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * k);
T
tink2123 已提交
270
    y_n = (y_n > 0) ? y_n : 0;
271
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
272 273 274
    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
275
    float d_s = 1.f - d_n;
276 277 278 279 280 281 282
    {
      vy_n[k] = y_n;
      vy_s[k] = y_s;
      vd_n[k] = d_n;
      vd_s[k] = d_s;
    }
  }
283

284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
  std::vector<int> vx_w, vx_e;
  std::vector<float> vd_w, vd_e;
  vx_w.reserve(out_w);
  vx_e.reserve(out_w);
  vd_w.reserve(out_w);
  vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int l = 0; l < out_w; l++) {
    int x_w = (align_mode == 0 && !align_corners)
                  ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                  : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;
    int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
299 300 301
    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
302 303 304 305 306 307 308 309
    float d_e = 1.f - d_w;
    {
      vx_w[l] = x_w;
      vx_e[l] = x_e;
      vd_w[l] = d_w;
      vd_e[l] = d_e;
    }
  }
310

311 312 313 314 315 316 317
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(4)
#endif
  for (int i = 0; i < n; i++) {          // loop for batches
    for (int j = 0; j < c; j++) {        // loop for channels
      for (int k = 0; k < out_h; k++) {  // loop for images
        for (int l = 0; l < out_w; l++) {
318
          // bilinear interpolation
319 320 321
          T out_t;
          if (data_layout == DataLayout::kNCHW) {
            out_t = input_t(i, j, vy_n[k], vx_w[l]) * vd_s[k] * vd_e[l] +
322 323 324
                    input_t(i, j, vy_s[k], vx_w[l]) * vd_n[k] * vd_e[l] +
                    input_t(i, j, vy_n[k], vx_e[l]) * vd_s[k] * vd_w[l] +
                    input_t(i, j, vy_s[k], vx_e[l]) * vd_n[k] * vd_w[l];
325 326 327 328 329 330 331 332 333
            output_t(i, j, k, l) = out_t;

          } else {
            out_t = input_t(i, vy_n[k], vx_w[l], j) * vd_s[k] * vd_e[l] +
                    input_t(i, vy_s[k], vx_w[l], j) * vd_n[k] * vd_e[l] +
                    input_t(i, vy_n[k], vx_e[l], j) * vd_s[k] * vd_w[l] +
                    input_t(i, vy_s[k], vx_e[l], j) * vd_n[k] * vd_w[l];
            output_t(i, k, l, j) = out_t;
          }
334 335 336 337 338 339
        }
      }
    }
  }
}

K
Kaipeng Deng 已提交
340
template <typename T>
341 342
static void TrilinearInterpolation(const phi::DenseTensor& input,
                                   phi::DenseTensor* output,
343 344 345 346 347 348 349 350 351 352 353 354 355 356
                                   const float ratio_d,
                                   const float ratio_h,
                                   const float ratio_w,
                                   const int in_d,
                                   const int in_h,
                                   const int in_w,
                                   const int n,
                                   const int c,
                                   const int out_d,
                                   const int out_h,
                                   const int out_w,
                                   const bool align_corners,
                                   const bool align_mode,
                                   const DataLayout& data_layout) {
K
Kaipeng Deng 已提交
357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448
  auto input_t = EigenTensor<T, 5>::From(input);
  auto output_t = EigenTensor<T, 5>::From(*output);
  bool align_flag = (align_mode == 0 && !align_corners);

  std::vector<int> vt_f, vt_b;
  std::vector<float> vd_f, vd_b;
  vt_f.reserve(out_d);
  vt_b.reserve(out_d);
  vd_f.reserve(out_d);
  vd_b.reserve(out_d);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int j = 0; j < out_d; j++) {
    int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
                         : static_cast<int>(ratio_d * j);
    t_f = (t_f > 0) ? t_f : 0;
    int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
    float idx_src_t = ratio_d * (j + 0.5) - 0.5;
    idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
    float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
    float d_b = 1.f - d_f;
    {
      vt_f[j] = t_f;
      vt_b[j] = t_b;
      vd_f[j] = d_f;
      vd_b[j] = d_b;
    }
  }

  std::vector<int> vy_n, vy_s;
  std::vector<float> vd_n, vd_s;
  vy_n.reserve(out_h);
  vy_s.reserve(out_h);
  vd_n.reserve(out_h);
  vd_s.reserve(out_h);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int k = 0; k < out_h; k++) {
    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * k);
    y_n = (y_n > 0) ? y_n : 0;
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
    float d_s = 1.f - d_n;
    {
      vy_n[k] = y_n;
      vy_s[k] = y_s;
      vd_n[k] = d_n;
      vd_s[k] = d_s;
    }
  }

  std::vector<int> vx_w, vx_e;
  std::vector<float> vd_w, vd_e;
  vx_w.reserve(out_w);
  vx_e.reserve(out_w);
  vd_w.reserve(out_w);
  vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int l = 0; l < out_w; l++) {
    int x_w = (align_mode == 0 && !align_corners)
                  ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                  : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;
    int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
    float d_e = 1.f - d_w;
    {
      vx_w[l] = x_w;
      vx_e[l] = x_e;
      vd_w[l] = d_w;
      vd_e[l] = d_e;
    }
  }

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(5)
#endif
  for (int b = 0; b < n; b++) {          // loop for batches
    for (int i = 0; i < c; i++) {        // loop for channels
      for (int j = 0; j < out_d; j++) {  // loop for D, H, W
        for (int k = 0; k < out_h; k++) {
          for (int l = 0; l < out_w; l++) {
            // trilinear interpolation
449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485
            if (data_layout == DataLayout::kNCHW) {
              T out_t = input_t(b, i, vt_f[j], vy_n[k], vx_w[l]) * vd_b[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, i, vt_f[j], vy_n[k], vx_e[l]) * vd_b[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, i, vt_f[j], vy_s[k], vx_w[l]) * vd_b[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, i, vt_f[j], vy_s[k], vx_e[l]) * vd_b[j] *
                            vd_n[k] * vd_w[l] +
                        input_t(b, i, vt_b[j], vy_n[k], vx_w[l]) * vd_f[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, i, vt_b[j], vy_n[k], vx_e[l]) * vd_f[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, i, vt_b[j], vy_s[k], vx_w[l]) * vd_f[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, i, vt_b[j], vy_s[k], vx_e[l]) * vd_f[j] *
                            vd_n[k] * vd_w[l];
              output_t(b, i, j, k, l) = out_t;
            } else {
              T out_t = input_t(b, vt_f[j], vy_n[k], vx_w[l], i) * vd_b[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, vt_f[j], vy_n[k], vx_e[l], i) * vd_b[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, vt_f[j], vy_s[k], vx_w[l], i) * vd_b[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, vt_f[j], vy_s[k], vx_e[l], i) * vd_b[j] *
                            vd_n[k] * vd_w[l] +
                        input_t(b, vt_b[j], vy_n[k], vx_w[l], i) * vd_f[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, vt_b[j], vy_n[k], vx_e[l], i) * vd_f[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, vt_b[j], vy_s[k], vx_w[l], i) * vd_f[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, vt_b[j], vy_s[k], vx_e[l], i) * vd_f[j] *
                            vd_n[k] * vd_w[l];
              output_t(b, j, k, l, i) = out_t;
            }
K
Kaipeng Deng 已提交
486 487 488 489 490 491 492
          }
        }
      }
    }
  }
}

X
xiaoting 已提交
493 494 495 496 497 498 499 500 501 502 503 504
/// Evaluates (A + 2)*x^3 - (A + 3)*x^2 + 1.
/// In this file it is called with distances in [0, 1], i.e. the two taps
/// nearest the sample point.
template <typename T>
HOSTDEVICE inline T cubic_convolution1(T x, T A) {
  T poly = (A + 2) * x - (A + 3);
  poly = poly * x;
  poly = poly * x;
  return poly + 1;
}

/// Evaluates A*x^3 - 5A*x^2 + 8A*x - 4A (Horner form).
/// In this file it is called with distances in [1, 2], i.e. the two outer
/// taps.
template <typename T>
HOSTDEVICE inline T cubic_convolution2(T x, T A) {
  T acc = A * x - 5 * A;
  acc = acc * x + 8 * A;
  acc = acc * x - 4 * A;
  return acc;
}

/// Fills `coeffs` with the four cubic-convolution tap weights for a sample
/// at fractional offset `t` past tap 1 (taps at offsets -1, 0, 1, 2).
/// The kernel parameter is fixed at A = -0.75.
template <typename T>
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) {
  const T A = static_cast<T>(-0.75);

  // Distance to the nearest tap on each side of the sample point.
  const T from_left = t;
  const T from_right = 1.0 - t;

  coeffs[0] = cubic_convolution2<T>(from_left + 1.0, A);
  coeffs[1] = cubic_convolution1<T>(from_left, A);
  // Mirrored weights for the taps to the right of the sample point.
  coeffs[2] = cubic_convolution1<T>(from_right, A);
  coeffs[3] = cubic_convolution2<T>(from_right + 1.0, A);
}

/// 1-D cubic interpolation of four equally spaced samples x0..x3, evaluated
/// at fractional position `t` between x1 and x2.
template <typename T>
static inline T cubic_interp(T x0, T x1, T x2, T x3, T t) {
  T weight[4];
  get_cubic_upsample_coefficients<T>(weight, t);
  return x0 * weight[0] + x1 * weight[1] + x2 * weight[2] + x3 * weight[3];
}

template <typename T>
526 527
static void BicubicInterpolation(const phi::DenseTensor& input,
                                 phi::DenseTensor* output,
528 529 530 531 532 533 534 535
                                 const float ratio_h,
                                 const float ratio_w,
                                 const int in_h,
                                 const int in_w,
                                 const int n,
                                 const int c,
                                 const int out_h,
                                 const int out_w,
X
xiaoting 已提交
536 537 538 539 540 541 542 543
                                 const bool align_corners,
                                 const DataLayout data_layout) {
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);

  for (int k = 0; k < out_h; k++) {  // loop for images
    T y_n = align_corners ? static_cast<T>(ratio_h * k)
                          : static_cast<T>(ratio_h * (k + 0.5) - 0.5);
544
    int input_y = floorf(y_n);
X
xiaoting 已提交
545 546 547 548 549
    const T y_t = y_n - input_y;

    for (int l = 0; l < out_w; l++) {
      T x_n = align_corners ? static_cast<T>(ratio_w * l)
                            : static_cast<T>(ratio_w * (l + 0.5) - 0.5);
550
      int input_x = floorf(x_n);
X
xiaoting 已提交
551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572
      const T x_t = x_n - input_x;

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          T coefficients[4];
          // interp 4 times in x direction
          for (int ii = 0; ii < 4; ii++) {
            int access_y = std::max(std::min(input_y - 1 + ii, in_h - 1),
                                    static_cast<int>(0));
            int access_x_0 =
                std::max(std::min(input_x - 1, in_w - 1), static_cast<int>(0));
            int access_x_1 =
                std::max(std::min(input_x + 0, in_w - 1), static_cast<int>(0));
            int access_x_2 =
                std::max(std::min(input_x + 1, in_w - 1), static_cast<int>(0));
            int access_x_3 =
                std::max(std::min(input_x + 2, in_w - 1), static_cast<int>(0));
            if (data_layout == DataLayout::kNCHW) {
              coefficients[ii] =
                  cubic_interp<T>(input_t(i, j, access_y, access_x_0),
                                  input_t(i, j, access_y, access_x_1),
                                  input_t(i, j, access_y, access_x_2),
573 574
                                  input_t(i, j, access_y, access_x_3),
                                  x_t);
X
xiaoting 已提交
575 576 577 578 579
            } else {
              coefficients[ii] =
                  cubic_interp<T>(input_t(i, access_y, access_x_0, j),
                                  input_t(i, access_y, access_x_1, j),
                                  input_t(i, access_y, access_x_2, j),
580 581
                                  input_t(i, access_y, access_x_3, j),
                                  x_t);
X
xiaoting 已提交
582 583 584 585 586
            }
          }

          // interp y direction
          if (data_layout == DataLayout::kNCHW) {
587 588 589 590 591
            output_t(i, j, k, l) = cubic_interp<T>(coefficients[0],
                                                   coefficients[1],
                                                   coefficients[2],
                                                   coefficients[3],
                                                   y_t);
X
xiaoting 已提交
592
          } else {
593 594 595 596 597
            output_t(i, k, l, j) = cubic_interp<T>(coefficients[0],
                                                   coefficients[1],
                                                   coefficients[2],
                                                   coefficients[3],
                                                   y_t);
X
xiaoting 已提交
598 599 600 601 602 603 604
          }
        }
      }
    }
  }
}

605
template <typename T>
606 607
static void NearestNeighborInterpolateGrad(const phi::DenseTensor& output_grad,
                                           phi::DenseTensor* input_grad,
608 609 610 611 612 613 614 615
                                           const float ratio_h,
                                           const float ratio_w,
                                           const int n,
                                           const int c,
                                           const int out_h,
                                           const int out_w,
                                           const bool align_corners,
                                           const DataLayout data_layout) {
616 617
  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
618

619
  for (int k = 0; k < out_h; k++) {  // loop for images
620 621
    int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
                               : static_cast<int>(ratio_h * k);
622 623

    for (int l = 0; l < out_w; l++) {
624 625
      int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
                                 : static_cast<int>(ratio_w * l);
626 627 628

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
629 630 631 632 633
          if (data_layout == DataLayout::kNCHW) {
            input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
          } else {
            input_grad_t(i, in_k, in_l, j) += output_grad_t(i, k, l, j);
          }
634 635 636 637 638 639 640
        }
      }
    }
  }
}

template <typename T>
641 642
static void BilinearInterpolationGrad(const phi::DenseTensor& output_grad,
                                      phi::DenseTensor* input_grad,
643 644 645 646 647 648 649 650 651 652 653
                                      const float ratio_h,
                                      const float ratio_w,
                                      const int in_h,
                                      const int in_w,
                                      const int n,
                                      const int c,
                                      const int out_h,
                                      const int out_w,
                                      const bool align_corners,
                                      const int align_mode,
                                      const DataLayout data_layout) {
654 655
  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
T
tink2123 已提交
656
  bool align_flag = (align_mode == 0 && !align_corners);
657
  for (int k = 0; k < out_h; k++) {  // loop for images
T
tink2123 已提交
658 659
    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * k);
T
tink2123 已提交
660
    y_n = (y_n > 0) ? y_n : 0;
661
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
662 663 664
    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
665 666 667
    float d_s = 1.f - d_n;

    for (int l = 0; l < out_w; l++) {
T
tink2123 已提交
668 669
      int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                           : static_cast<int>(ratio_w * l);
T
tink2123 已提交
670
      x_w = (x_w > 0) ? x_w : 0;
671
      int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
672 673 674
      float idx_src_x = ratio_w * (l + 0.5) - 0.5;
      idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
      float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
675 676 677 678 679
      float d_e = 1.f - d_w;

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          // bilinear interpolation grad
680 681 682 683 684 685 686 687 688 689 690 691 692
          if (data_layout == DataLayout::kNCHW) {
            const T grad = output_grad_t(i, j, k, l);
            input_grad_t(i, j, y_n, x_w) += static_cast<T>(grad * d_s * d_e);
            input_grad_t(i, j, y_s, x_w) += static_cast<T>(grad * d_n * d_e);
            input_grad_t(i, j, y_n, x_e) += static_cast<T>(grad * d_s * d_w);
            input_grad_t(i, j, y_s, x_e) += static_cast<T>(grad * d_n * d_w);
          } else {
            const T grad = output_grad_t(i, k, l, j);
            input_grad_t(i, y_n, x_w, j) += static_cast<T>(grad * d_s * d_e);
            input_grad_t(i, y_s, x_w, j) += static_cast<T>(grad * d_n * d_e);
            input_grad_t(i, y_n, x_e, j) += static_cast<T>(grad * d_s * d_w);
            input_grad_t(i, y_s, x_e, j) += static_cast<T>(grad * d_n * d_w);
          }
693 694 695 696 697
        }
      }
    }
  }
}
K
Kaipeng Deng 已提交
698

699
template <typename T>
700 701
static void TrilinearInterpolationGrad(const phi::DenseTensor& output_grad,
                                       phi::DenseTensor* input_grad,
702 703 704 705 706 707 708 709 710 711 712 713 714 715
                                       const float ratio_d,
                                       const float ratio_h,
                                       const float ratio_w,
                                       const int in_d,
                                       const int in_h,
                                       const int in_w,
                                       const int n,
                                       const int c,
                                       const int out_d,
                                       const int out_h,
                                       const int out_w,
                                       const bool align_corners,
                                       const int align_mode,
                                       const DataLayout data_layout) {
K
Kaipeng Deng 已提交
716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751
  auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
  bool align_flag = (align_mode == 0 && !align_corners);
  for (int j = 0; j < out_d; j++) {  // loop for D
    int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
                         : static_cast<int>(ratio_d * j);
    t_f = (t_f > 0) ? t_f : 0;
    int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
    float idx_src_t = ratio_d * (j + 0.5) - 0.5;
    idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
    float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
    float d_b = 1.f - d_f;

    for (int k = 0; k < out_h; k++) {  // loop for H
      int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                           : static_cast<int>(ratio_h * k);
      y_n = (y_n > 0) ? y_n : 0;
      int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
      float idx_src_y = ratio_h * (k + 0.5) - 0.5;
      idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
      float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
      float d_s = 1.f - d_n;

      for (int l = 0; l < out_w; l++) {  // loop for W
        int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                             : static_cast<int>(ratio_w * l);
        x_w = (x_w > 0) ? x_w : 0;
        int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
        float idx_src_x = ratio_w * (l + 0.5) - 0.5;
        idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
        float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
        float d_e = 1.f - d_w;

        for (int b = 0; b < n; b++) {    // loop for batches
          for (int i = 0; i < c; i++) {  // loop for channels
            // trilinear interpolation grad
752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788
            if (data_layout == DataLayout::kNCHW) {
              const T grad = output_grad_t(b, i, j, k, l);
              input_grad_t(b, i, t_f, y_n, x_w) +=
                  static_cast<T>(grad * d_b * d_s * d_e);
              input_grad_t(b, i, t_f, y_n, x_e) +=
                  static_cast<T>(grad * d_b * d_s * d_w);
              input_grad_t(b, i, t_f, y_s, x_w) +=
                  static_cast<T>(grad * d_b * d_n * d_e);
              input_grad_t(b, i, t_f, y_s, x_e) +=
                  static_cast<T>(grad * d_b * d_n * d_w);
              input_grad_t(b, i, t_b, y_n, x_w) +=
                  static_cast<T>(grad * d_f * d_s * d_e);
              input_grad_t(b, i, t_b, y_n, x_e) +=
                  static_cast<T>(grad * d_f * d_s * d_w);
              input_grad_t(b, i, t_b, y_s, x_w) +=
                  static_cast<T>(grad * d_f * d_n * d_e);
              input_grad_t(b, i, t_b, y_s, x_e) +=
                  static_cast<T>(grad * d_f * d_n * d_w);
            } else {
              const T grad = output_grad_t(b, j, k, l, i);
              input_grad_t(b, t_f, y_n, x_w, i) +=
                  static_cast<T>(grad * d_b * d_s * d_e);
              input_grad_t(b, t_f, y_n, x_e, i) +=
                  static_cast<T>(grad * d_b * d_s * d_w);
              input_grad_t(b, t_f, y_s, x_w, i) +=
                  static_cast<T>(grad * d_b * d_n * d_e);
              input_grad_t(b, t_f, y_s, x_e, i) +=
                  static_cast<T>(grad * d_b * d_n * d_w);
              input_grad_t(b, t_b, y_n, x_w, i) +=
                  static_cast<T>(grad * d_f * d_s * d_e);
              input_grad_t(b, t_b, y_n, x_e, i) +=
                  static_cast<T>(grad * d_f * d_s * d_w);
              input_grad_t(b, t_b, y_s, x_w, i) +=
                  static_cast<T>(grad * d_f * d_n * d_e);
              input_grad_t(b, t_b, y_s, x_e, i) +=
                  static_cast<T>(grad * d_f * d_n * d_w);
            }
K
Kaipeng Deng 已提交
789 790 791 792 793 794
          }
        }
      }
    }
  }
}
795

X
xiaoting 已提交
796
template <typename T>
797 798
static void BicubicInterpolationGrad(const phi::DenseTensor& output_grad,
                                     phi::DenseTensor* input_grad,
799 800 801 802 803 804 805 806
                                     const float ratio_h,
                                     const float ratio_w,
                                     const int in_h,
                                     const int in_w,
                                     const int n,
                                     const int c,
                                     const int out_h,
                                     const int out_w,
X
xiaoting 已提交
807 808 809 810 811 812 813 814
                                     const bool align_corners,
                                     const DataLayout data_layout) {
  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);

  for (int k = 0; k < out_h; k++) {  // loop for images
    T y_n = align_corners ? static_cast<T>(ratio_h * k)
                          : static_cast<T>(ratio_h * (k + 0.5) - 0.5);
815
    int input_y = floorf(y_n);
X
xiaoting 已提交
816 817 818 819 820
    T y_t = y_n - input_y;

    for (int l = 0; l < out_w; l++) {
      T x_n = align_corners ? static_cast<T>(ratio_w * l)
                            : static_cast<T>(ratio_w * (l + 0.5) - 0.5);
821
      int input_x = floorf(x_n);
X
xiaoting 已提交
822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855
      T x_t = x_n - input_x;

      T x_coeffs[4];
      T y_coeffs[4];

      get_cubic_upsample_coefficients<T>(x_coeffs, x_t);
      get_cubic_upsample_coefficients<T>(y_coeffs, y_t);

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          // bicubic interpolation grad
          for (int ii = 0; ii < 4; ii++) {
            for (int jj = 0; jj < 4; jj++) {
              int access_x = std::max(std::min(input_x - 1 + ii, in_w - 1),
                                      static_cast<int>(0));
              int access_y = std::max(std::min(input_y - 1 + jj, in_h - 1),
                                      static_cast<int>(0));
              if (data_layout == DataLayout::kNCHW) {
                T grad = output_grad_t(i, j, k, l);
                input_grad_t(i, j, access_y, access_x) +=
                    grad * y_coeffs[jj] * x_coeffs[ii];
              } else {
                T grad = output_grad_t(i, k, l, j);
                input_grad_t(i, access_y, access_x, j) +=
                    grad * y_coeffs[jj] * x_coeffs[ii];
              }
            }
          }
        }
      }
    }
  }
}

856 857
// CPU forward of 1-D interpolation.
//
// Resolves the output width, allocates `output` in the same layout as
// `input`, then either copies `input` through (equal widths) or resamples it
// with LinearInterpolation when interp_method is "linear".
template <typename T>
static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
                                const phi::DenseTensor& input,
                                phi::DenseTensor* output) {
  const DataLayout data_layout =
      phi::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  const auto interp_method = ctx.Attr<std::string>("interp_method");
  const bool align_corners = ctx.Attr<bool>("align_corners");
  const int align_mode = ctx.Attr<int>("align_mode");

  // Output width: SizeTensor wins outright. Otherwise OutSize overrides a
  // positive scale (Scale tensor, else the "scale" attribute), which in turn
  // overrides the "out_w" attribute.
  int out_w = ctx.Attr<int>("out_w");
  const auto size_tensors = ctx.MultiInput<phi::DenseTensor>("SizeTensor");
  if (!size_tensors.empty()) {
    out_w = get_new_shape(size_tensors)[0];
  } else {
    const auto* scale_tensor = ctx.Input<phi::DenseTensor>("Scale");
    const float scale = (scale_tensor != nullptr)
                            ? get_new_data_from_tensor<float>(scale_tensor)[0]
                            : ctx.Attr<float>("scale");
    if (scale > 0) {
      out_w = static_cast<int>(in_w * scale);
    }
    if (const auto* out_size = ctx.Input<phi::DenseTensor>("OutSize")) {
      out_w = get_new_data_from_tensor<int>(out_size)[0];
    }
  }
  PADDLE_ENFORCE_GT(out_w,
                    0,
                    platform::errors::InvalidArgument(
                        "out_w in Attr(out_shape) of Op(interpolate) "
                        "should be greater than 0."));

  const bool channel_first = (data_layout == DataLayout::kNCHW);
  output->mutable_data<T>(channel_first ? phi::make_ddim({n, c, out_w})
                                        : phi::make_ddim({n, out_w, c}),
                          ctx.GetPlace());

  // Same width: interpolation is the identity, copy straight through.
  if (in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Source step per destination step; 0 when the output has one sample.
  const float ratio_w =
      (out_w > 1) ? (align_corners ? static_cast<float>(in_w - 1) / (out_w - 1)
                                   : static_cast<float>(in_w) / out_w)
                  : 0.f;

  if (interp_method == "linear") {
    LinearInterpolation<T>(input,
                           output,
                           ratio_w,
                           in_w,
                           n,
                           c,
                           out_w,
                           align_corners,
                           align_mode,
                           data_layout);
  }
}

K
Kaipeng Deng 已提交
930 931
// CPU forward of 2-D interpolation.
//
// Resolves (out_h, out_w), allocates `output` in the same layout as `input`,
// then either copies `input` through (equal sizes) or resamples it with the
// bilinear, nearest or bicubic helper selected by interp_method.
template <typename T>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
                                const phi::DenseTensor& input,
                                phi::DenseTensor* output) {
  const DataLayout data_layout =
      phi::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  const auto interp_method = ctx.Attr<std::string>("interp_method");
  const bool align_corners = ctx.Attr<bool>("align_corners");
  const int align_mode = ctx.Attr<int>("align_mode");

  // Output size: SizeTensor wins outright. Otherwise OutSize overrides a
  // positive scale (Scale tensor, else the "scale" attribute), which in turn
  // overrides the "out_h"/"out_w" attributes.
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  const auto size_tensors = ctx.MultiInput<phi::DenseTensor>("SizeTensor");
  if (!size_tensors.empty()) {
    const auto sizes = get_new_shape(size_tensors);
    out_h = sizes[0];
    out_w = sizes[1];
  } else {
    const auto* scale_tensor = ctx.Input<phi::DenseTensor>("Scale");
    const float scale = (scale_tensor != nullptr)
                            ? get_new_data_from_tensor<float>(scale_tensor)[0]
                            : ctx.Attr<float>("scale");
    if (scale > 0) {
      out_h = static_cast<int>(in_h * scale);
      out_w = static_cast<int>(in_w * scale);
    }
    if (const auto* out_size = ctx.Input<phi::DenseTensor>("OutSize")) {
      const auto sizes = get_new_data_from_tensor<int>(out_size);
      out_h = sizes[0];
      out_w = sizes[1];
    }
  }
  PADDLE_ENFORCE_GT(out_h,
                    0,
                    platform::errors::InvalidArgument(
                        "out_h in Attr(out_shape) of Op(interpolate) "
                        "should be greater than 0."));
  PADDLE_ENFORCE_GT(out_w,
                    0,
                    platform::errors::InvalidArgument(
                        "out_w in Attr(out_shape) of Op(interpolate) "
                        "should be greater than 0."));

  const bool channel_first = (data_layout == DataLayout::kNCHW);
  output->mutable_data<T>(channel_first ? phi::make_ddim({n, c, out_h, out_w})
                                        : phi::make_ddim({n, out_h, out_w, c}),
                          ctx.GetPlace());

  // Same spatial size: interpolation is the identity, copy straight through.
  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Source step per destination step along one axis; 0 for a single sample.
  auto step = [align_corners](int in_len, int out_len) -> float {
    if (out_len > 1) {
      return align_corners ? static_cast<float>(in_len - 1) / (out_len - 1)
                           : static_cast<float>(in_len) / out_len;
    }
    return 0.f;
  };
  const float ratio_h = step(in_h, out_h);
  const float ratio_w = step(in_w, out_w);

  if (interp_method == "bilinear") {
    BilinearInterpolation<T>(input,
                             output,
                             ratio_h,
                             ratio_w,
                             in_h,
                             in_w,
                             n,
                             c,
                             out_h,
                             out_w,
                             align_corners,
                             align_mode,
                             data_layout);
  } else if (interp_method == "nearest") {
    NearestNeighborInterpolate<T>(input,
                                  output,
                                  ratio_h,
                                  ratio_w,
                                  n,
                                  c,
                                  out_h,
                                  out_w,
                                  align_corners,
                                  data_layout);
  } else if (interp_method == "bicubic") {
    BicubicInterpolation<T>(input,
                            output,
                            ratio_h,
                            ratio_w,
                            in_h,
                            in_w,
                            n,
                            c,
                            out_h,
                            out_w,
                            align_corners,
                            data_layout);
  }
}
1046

K
Kaipeng Deng 已提交
1047 1048
// CPU forward of 3-D interpolation.
//
// Resolves (out_d, out_h, out_w), allocates `output` in the same layout as
// `input`, then either copies `input` through (equal sizes) or resamples it
// with TrilinearInterpolation when interp_method is "trilinear".
template <typename T>
static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
                                const phi::DenseTensor& input,
                                phi::DenseTensor* output) {
  const DataLayout data_layout =
      phi::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  const auto interp_method = ctx.Attr<std::string>("interp_method");
  const bool align_corners = ctx.Attr<bool>("align_corners");
  const int align_mode = ctx.Attr<int>("align_mode");

  // Output size: SizeTensor wins outright. Otherwise OutSize overrides a
  // positive scale (Scale tensor, else the "scale" attribute), which in turn
  // overrides the "out_d"/"out_h"/"out_w" attributes.
  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  const auto size_tensors = ctx.MultiInput<phi::DenseTensor>("SizeTensor");
  if (!size_tensors.empty()) {
    const auto sizes = get_new_shape(size_tensors);
    out_d = sizes[0];
    out_h = sizes[1];
    out_w = sizes[2];
  } else {
    const auto* scale_tensor = ctx.Input<phi::DenseTensor>("Scale");
    const float scale = (scale_tensor != nullptr)
                            ? get_new_data_from_tensor<float>(scale_tensor)[0]
                            : ctx.Attr<float>("scale");
    if (scale > 0) {
      out_d = static_cast<int>(in_d * scale);
      out_h = static_cast<int>(in_h * scale);
      out_w = static_cast<int>(in_w * scale);
    }
    if (const auto* out_size = ctx.Input<phi::DenseTensor>("OutSize")) {
      const auto sizes = get_new_data_from_tensor<int>(out_size);
      out_d = sizes[0];
      out_h = sizes[1];
      out_w = sizes[2];
    }
  }
  PADDLE_ENFORCE_GT(out_d,
                    0,
                    platform::errors::InvalidArgument(
                        "out_d in Attr(out_shape) of Op(interpolate) "
                        "should be greater than 0."));
  PADDLE_ENFORCE_GT(out_h,
                    0,
                    platform::errors::InvalidArgument(
                        "out_h in Attr(out_shape) of Op(interpolate) "
                        "should be greater than 0."));
  PADDLE_ENFORCE_GT(out_w,
                    0,
                    platform::errors::InvalidArgument(
                        "out_w in Attr(out_shape) of Op(interpolate) "
                        "should be greater than 0."));

  const bool channel_first = (data_layout == DataLayout::kNCHW);
  output->mutable_data<T>(
      channel_first ? phi::make_ddim({n, c, out_d, out_h, out_w})
                    : phi::make_ddim({n, out_d, out_h, out_w, c}),
      ctx.GetPlace());

  // Same spatial size: interpolation is the identity, copy straight through.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Source step per destination step along one axis; 0 for a single sample.
  auto step = [align_corners](int in_len, int out_len) -> float {
    if (out_len > 1) {
      return align_corners ? static_cast<float>(in_len - 1) / (out_len - 1)
                           : static_cast<float>(in_len) / out_len;
    }
    return 0.f;
  };
  const float ratio_d = step(in_d, out_d);
  const float ratio_h = step(in_h, out_h);
  const float ratio_w = step(in_w, out_w);

  if (interp_method == "trilinear") {
    TrilinearInterpolation<T>(input,
                              output,
                              ratio_d,
                              ratio_h,
                              ratio_w,
                              in_d,
                              in_h,
                              in_w,
                              n,
                              c,
                              out_d,
                              out_h,
                              out_w,
                              align_corners,
                              align_mode,
                              data_layout);
  }
}
1158

1159 1160
// CPU backward of 1-D interpolation.
//
// Re-derives the forward output width, then zero-fills `input_grad` with the
// shape of input X. Finally it either copies `output_grad` through (equal
// widths) or scatters it with LinearInterpolationGrad for "linear".
template <typename T>
static void Interpolate1DCPUBwd(const framework::ExecutionContext& ctx,
                                phi::DenseTensor* input_grad,
                                const phi::DenseTensor& output_grad) {
  const auto* input = ctx.Input<phi::DenseTensor>("X");
  const DataLayout data_layout =
      phi::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  const auto interp_method = ctx.Attr<std::string>("interp_method");
  const bool align_corners = ctx.Attr<bool>("align_corners");
  const int align_mode = ctx.Attr<int>("align_mode");

  // Forward output width. Each later source overrides the previous one:
  // "out_w" attribute, a positive scale, OutSize, then SizeTensor.
  int out_w = ctx.Attr<int>("out_w");
  const auto* scale_tensor = ctx.Input<phi::DenseTensor>("Scale");
  const float scale = (scale_tensor != nullptr)
                          ? get_new_data_from_tensor<float>(scale_tensor)[0]
                          : ctx.Attr<float>("scale");
  if (scale > 0) {
    out_w = static_cast<int>(in_w * scale);
  }
  if (const auto* out_size = ctx.Input<phi::DenseTensor>("OutSize")) {
    out_w = get_new_data_from_tensor<int>(out_size)[0];
  }
  const auto size_tensors = ctx.MultiInput<phi::DenseTensor>("SizeTensor");
  if (!size_tensors.empty()) {
    out_w = get_new_shape(size_tensors)[0];
  }

  const bool channel_first = (data_layout == DataLayout::kNCHW);
  input_grad->mutable_data<T>(channel_first ? phi::make_ddim({n, c, in_w})
                                            : phi::make_ddim({n, in_w, c}),
                              ctx.GetPlace());
  // Zero-fill first: the grad helper accumulates into input_grad.
  phi::funcs::SetConstant<phi::CPUContext, T> fill;
  fill(ctx.template device_context<phi::CPUContext>(),
       input_grad,
       static_cast<T>(0.0));

  // Same width: the forward was a copy, so the gradient passes through.
  if (in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // Source step per destination step; 0 when the output has one sample.
  const float ratio_w =
      (out_w > 1) ? (align_corners ? static_cast<float>(in_w - 1) / (out_w - 1)
                                   : static_cast<float>(in_w) / out_w)
                  : 0.f;

  if (interp_method == "linear") {
    LinearInterpolationGrad<T>(output_grad,
                               input_grad,
                               ratio_w,
                               in_w,
                               n,
                               c,
                               out_w,
                               align_corners,
                               align_mode,
                               data_layout);
  }
}

1233
template <typename T>
K
Kaipeng Deng 已提交
1234
static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
1235 1236 1237
                                phi::DenseTensor* input_grad,
                                const phi::DenseTensor& output_grad) {
  auto* input = ctx.Input<phi::DenseTensor>("X");
1238
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
1239
  const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
1240 1241
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
K
Kaipeng Deng 已提交
1242 1243 1244 1245 1246 1247 1248

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
1249
  float scale;
1250
  auto scale_tensor = ctx.Input<phi::DenseTensor>("Scale");
1251 1252 1253 1254 1255 1256
  if (scale_tensor != nullptr) {
    auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
    scale = scale_data[0];
  } else {
    scale = ctx.Attr<float>("scale");
  }
K
Kaipeng Deng 已提交
1257 1258 1259 1260
  if (scale > 0) {
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }
1261
  auto out_size = ctx.Input<phi::DenseTensor>("OutSize");
K
Kaipeng Deng 已提交
1262
  if (out_size != nullptr) {
1263
    auto out_size_data = get_new_data_from_tensor<int>(out_size);
K
Kaipeng Deng 已提交
1264 1265 1266
    out_h = out_size_data[0];
    out_w = out_size_data[1];
  }
1267
  auto list_new_size_tensor = ctx.MultiInput<phi::DenseTensor>("SizeTensor");
1268 1269 1270 1271 1272 1273
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    out_h = new_size[0];
    out_w = new_size[1];
  }
D
dengkaipeng 已提交
1274

1275 1276 1277 1278 1279 1280 1281 1282
  framework::DDim dim_grad;
  if (data_layout == DataLayout::kNCHW) {
    dim_grad = {n, c, in_h, in_w};
  } else {
    dim_grad = {n, in_h, in_w, c};
  }
  input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());

L
Leo Chen 已提交
1283 1284
  auto& device_ctx = ctx.template device_context<phi::CPUContext>();
  phi::funcs::SetConstant<phi::CPUContext, T> zero;
K
Kaipeng Deng 已提交
1285
  zero(device_ctx, input_grad, static_cast<T>(0.0));
D
dengkaipeng 已提交
1286

K
Kaipeng Deng 已提交
1287 1288 1289 1290
  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }
D
dengkaipeng 已提交
1291

K
Kaipeng Deng 已提交
1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303
  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  if ("bilinear" == interp_method) {
1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316
    BilinearInterpolationGrad<T>(output_grad,
                                 input_grad,
                                 ratio_h,
                                 ratio_w,
                                 in_h,
                                 in_w,
                                 n,
                                 c,
                                 out_h,
                                 out_w,
                                 align_corners,
                                 align_mode,
                                 data_layout);
K
Kaipeng Deng 已提交
1317
  } else if ("nearest" == interp_method) {
1318 1319 1320 1321 1322 1323 1324 1325 1326
    NearestNeighborInterpolateGrad<T>(output_grad,
                                      input_grad,
                                      ratio_h,
                                      ratio_w,
                                      n,
                                      c,
                                      out_h,
                                      out_w,
                                      align_corners,
1327
                                      data_layout);
X
xiaoting 已提交
1328
  } else if ("bicubic" == interp_method) {
1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339
    BicubicInterpolationGrad<T>(output_grad,
                                input_grad,
                                ratio_h,
                                ratio_w,
                                in_h,
                                in_w,
                                n,
                                c,
                                out_h,
                                out_w,
                                align_corners,
X
xiaoting 已提交
1340
                                data_layout);
K
Kaipeng Deng 已提交
1341 1342
  }
}
D
dengkaipeng 已提交
1343

K
Kaipeng Deng 已提交
1344 1345
// CPU backward of 3-D interpolation.
//
// Re-derives the forward output size. Each later source overrides the
// previous one: "out_d"/"out_h"/"out_w" attributes, a positive scale,
// OutSize, then SizeTensor. It then zero-fills `input_grad` with the shape of
// input X. Finally it either copies `output_grad` through (equal sizes) or
// scatters it with TrilinearInterpolationGrad for "trilinear".
//
// `output_grad` is taken by const reference, matching the 1D/2D backward
// functions and avoiding a copy of the tensor object per call.
template <typename T>
static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
                                phi::DenseTensor* input_grad,
                                const phi::DenseTensor& output_grad) {
  auto* input = ctx.Input<phi::DenseTensor>("X");
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
  const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  float scale;
  auto scale_tensor = ctx.Input<phi::DenseTensor>("Scale");
  if (scale_tensor != nullptr) {
    auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
    scale = scale_data[0];
  } else {
    scale = ctx.Attr<float>("scale");
  }
  if (scale > 0) {
    out_d = static_cast<int>(in_d * scale);
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }
  auto out_size = ctx.Input<phi::DenseTensor>("OutSize");
  if (out_size != nullptr) {
    auto out_size_data = get_new_data_from_tensor<int>(out_size);
    out_d = out_size_data[0];
    out_h = out_size_data[1];
    out_w = out_size_data[2];
  }
  auto list_new_size_tensor = ctx.MultiInput<phi::DenseTensor>("SizeTensor");
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    out_d = new_size[0];
    out_h = new_size[1];
    out_w = new_size[2];
  }

  framework::DDim dim_grad;
  if (data_layout == DataLayout::kNCHW) {
    dim_grad = {n, c, in_d, in_h, in_w};
  } else {
    dim_grad = {n, in_d, in_h, in_w, c};
  }
  input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
  // Zero-fill first: TrilinearInterpolationGrad accumulates with +=.
  auto& device_ctx = ctx.template device_context<phi::CPUContext>();
  phi::funcs::SetConstant<phi::CPUContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  // Same spatial size: the forward was a copy, so the gradient passes through.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // Source step per destination step; stays 0 for a single-sample axis.
  float ratio_d = 0.f;
  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_d > 1) {
    ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
                              : static_cast<float>(in_d) / out_d;
  }
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  if ("trilinear" == interp_method) {
    TrilinearInterpolationGrad<T>(output_grad,
                                  input_grad,
                                  ratio_d,
                                  ratio_h,
                                  ratio_w,
                                  in_d,
                                  in_h,
                                  in_w,
                                  n,
                                  c,
                                  out_d,
                                  out_h,
                                  out_w,
                                  align_corners,
                                  align_mode,
                                  data_layout);
  }
}

template <typename T>
class InterpolateKernel : public framework::OpKernel<T> {
 public:
  // Forward CPU kernel. Dispatches on the rank of X:
  // 3 -> 1-D, 4 -> 2-D, 5 -> 3-D interpolation.
  // Any other rank is not handled here.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* x = ctx.Input<phi::DenseTensor>("X");
    auto* out = ctx.Output<phi::DenseTensor>("Out");

    switch (x->dims().size()) {
      case 3:
        Interpolate1DCPUFwd<T>(ctx, *x, out);
        break;
      case 4:
        Interpolate2DCPUFwd<T>(ctx, *x, out);
        break;
      case 5:
        Interpolate3DCPUFwd<T>(ctx, *x, out);
        break;
      default:
        break;
    }
  }
};

template <typename T>
class InterpolateGradKernel : public framework::OpKernel<T> {
 public:
  // Backward CPU kernel. Dispatches on the rank of Out@GRAD:
  // 3 -> 1-D, 4 -> 2-D, 5 -> 3-D interpolation grad.
  // Any other rank is not handled here.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* dx = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
    const auto* dout =
        ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));

    switch (dout->dims().size()) {
      case 3:
        Interpolate1DCPUBwd<T>(ctx, dx, *dout);
        break;
      case 4:
        Interpolate2DCPUBwd<T>(ctx, dx, *dout);
        break;
      case 5:
        Interpolate3DCPUBwd<T>(ctx, dx, *dout);
        break;
      default:
        break;
    }
  }
};

}  // namespace operators
}  // namespace paddle