interpolate_v2_op.h 62.5 KB
Newer Older
X
xiaoting 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

// Shorthand for framework::EigenTensor. Unless overridden, the alias uses
// row-major storage (Eigen::RowMajor) and Eigen::DenseIndex as index type.
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
// Short names for the framework types used by the helpers in this header.
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;

// Collects an output shape from a list of single-element int32 tensors.
//
// Every tensor in `list_new_shape_tensor` must have dims equal to [1]
// (enforced below); its only element is appended to the result, in list
// order. A tensor that lives on a GPU place is first copied synchronously to
// the CPU so that its value can be read on the host.
//
// Returns the collected values as a std::vector<int>.
inline std::vector<int> get_new_shape(
    const std::vector<const Tensor*>& list_new_shape_tensor) {
  std::vector<int> vec_new_shape;
  vec_new_shape.reserve(list_new_shape_tensor.size());
  for (const Tensor* tensor : list_new_shape_tensor) {
    // The placeholder must be "%s" so the offending dims are actually
    // rendered in the message.
    PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
                      platform::errors::InvalidArgument(
                          "The shape of dimension tensor should be [1], "
                          "but received %s.",
                          tensor->dims()));
    if (platform::is_gpu_place(tensor->place())) {
      // Device memory cannot be dereferenced on the host: stage a CPU copy.
      framework::Tensor temp;
      TensorCopySync(*tensor, platform::CPUPlace(), &temp);
      vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
    } else {
      vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
    }
  }

  return vec_new_shape;
}

template <typename T>
inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
  std::vector<T> vec_new_data;
  auto* new_data = new_data_tensor->data<T>();
  framework::Tensor cpu_starts_tensor;
  if (platform::is_gpu_place(new_data_tensor->place())) {
    TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
    new_data = cpu_starts_tensor.data<T>();
  }
61 62 63 64 65 66
#ifdef PADDLE_WITH_ASCEND_CL
  if (platform::is_npu_place(new_data_tensor->place())) {
    TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
    new_data = cpu_starts_tensor.data<T>();
  }
#endif
X
xiaoting 已提交
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
  vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
  return vec_new_data;
}

// Splits `dims` into batch (N), channel (C), depth (D), height (H) and
// width (W) according to `data_layout`.
//
// With DataLayout::kNCHW the channel is read from axis 1 and the spatial axes
// start at axis 2; with any other layout the spatial axes start at axis 1 and
// the channel is read from the last axis of the respective rank.
//   * 3 dims: D and H are set to 1, only W is read.
//   * 4 dims: D is set to 1, H and W are read.
//   * otherwise: D, H and W are read (the indices used are those of 5 dims).
inline void ExtractNCDWH(const framework::DDim& dims,
                         const DataLayout& data_layout, int* N, int* C, int* D,
                         int* H, int* W) {
  const bool channel_first = (data_layout == DataLayout::kNCHW);
  // Index of the first spatial axis for the given layout.
  const int spatial = channel_first ? 2 : 1;

  *N = dims[0];

  switch (dims.size()) {
    case 3:
      *C = dims[channel_first ? 1 : 2];
      *D = 1;
      *H = 1;
      *W = dims[spatial];
      break;
    case 4:
      *C = dims[channel_first ? 1 : 3];
      *D = 1;
      *H = dims[spatial];
      *W = dims[spatial + 1];
      break;
    default:
      *C = dims[channel_first ? 1 : 4];
      *D = dims[spatial];
      *H = dims[spatial + 1];
      *W = dims[spatial + 2];
      break;
  }
}

// Nearest-neighbour 2-D resize on the CPU.
//
// For each output position (k, l) the source position is ratio * index,
// truncated to int; with `align_corners` 0.5 is added before truncation.
// The selected input element is copied to the output for every batch `i` in
// [0, n) and every channel `j` in [0, c). Tensors are indexed as
// (i, j, h, w) for DataLayout::kNCHW and as (i, h, w, j) otherwise.
template <typename T>
static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
                                       const float ratio_h, const float ratio_w,
                                       const int n, const int c,
                                       const int out_h, const int out_w,
                                       const bool align_corners,
                                       const DataLayout& data_layout) {
  auto src = EigenTensor<T, 4>::From(input);
  auto dst = EigenTensor<T, 4>::From(*output);
  const bool channel_first = (data_layout == DataLayout::kNCHW);

  // Maps an output index to its nearest source index.
  const auto nearest = [align_corners](const float ratio, const int dst_idx) {
    return align_corners ? static_cast<int>(ratio * dst_idx + 0.5)
                         : static_cast<int>(ratio * dst_idx);
  };

  for (int oh = 0; oh < out_h; ++oh) {
    const int ih = nearest(ratio_h, oh);
    for (int ow = 0; ow < out_w; ++ow) {
      const int iw = nearest(ratio_w, ow);
      for (int batch = 0; batch < n; ++batch) {
        for (int ch = 0; ch < c; ++ch) {
          if (channel_first) {
            dst(batch, ch, oh, ow) = src(batch, ch, ih, iw);
          } else {
            dst(batch, oh, ow, ch) = src(batch, ih, iw, ch);
          }
        }
      }
    }
  }
}

124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
// Nearest-neighbour 3-D resize on the CPU.
//
// For each output position (d, k, l) the source position along every axis is
// ratio * index truncated to int; with `align_corners` 0.5 is added before
// truncation. The selected input element is copied for every batch in [0, n)
// and every channel in [0, c). Tensors are indexed as (i, j, d, h, w) for
// DataLayout::kNCHW and as (i, d, h, w, j) otherwise.
template <typename T>
static void NearestNeighbor3DInterpolate(
    const Tensor& input, Tensor* output, const float ratio_d,
    const float ratio_h, const float ratio_w, const int n, const int c,
    const int out_d, const int out_h, const int out_w, const bool align_corners,
    const DataLayout& data_layout) {
  auto src = EigenTensor<T, 5>::From(input);
  auto dst = EigenTensor<T, 5>::From(*output);
  const bool channel_first = (data_layout == DataLayout::kNCHW);

  // Maps an output index to its nearest source index.
  const auto nearest = [align_corners](const float ratio, const int dst_idx) {
    return align_corners ? static_cast<int>(ratio * dst_idx + 0.5)
                         : static_cast<int>(ratio * dst_idx);
  };

  for (int od = 0; od < out_d; ++od) {
    const int id = nearest(ratio_d, od);
    for (int oh = 0; oh < out_h; ++oh) {
      const int ih = nearest(ratio_h, oh);
      for (int ow = 0; ow < out_w; ++ow) {
        const int iw = nearest(ratio_w, ow);
        for (int batch = 0; batch < n; ++batch) {
          for (int ch = 0; ch < c; ++ch) {
            if (channel_first) {
              dst(batch, ch, od, oh, ow) = src(batch, ch, id, ih, iw);
            } else {  // channel-last
              dst(batch, od, oh, ow, ch) = src(batch, id, ih, iw, ch);
            }
          }
        }
      }
    }
  }
}

X
xiaoting 已提交
157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619
// 1-D linear resize on the CPU.
//
// Step 1 builds, for every output column l in [0, out_w), the two source
// columns to blend (vx_w / vx_e) and their weights (vd_e / vd_w).
// Step 2 writes output(l) = input(vx_w) * vd_e + input(vx_e) * vd_w for every
// batch in [0, n) and channel in [0, c).
//
// When `align_mode == 0 && !align_corners` the source coordinate is
// ratio_w * (l + 0.5) - 0.5 (clamped at 0); otherwise it is ratio_w * l.
// Tensors are indexed as (i, j, w) for DataLayout::kNCHW and (i, w, j)
// otherwise. Both loops run under OpenMP when PADDLE_WITH_MKLML is defined.
//
// NOTE(review): `align_mode` is declared `bool` here but `int` in
// LinearInterpolationGrad; the `== 0` test gives the same result for both.
template <typename T>
static void LinearInterpolation(const Tensor& input, Tensor* output,
                                const float ratio_w, const int in_w,
                                const int n, const int c, const int out_w,
                                const bool align_corners, const bool align_mode,
                                const DataLayout data_layout) {
  auto input_t = EigenTensor<T, 3>::From(input);
  auto output_t = EigenTensor<T, 3>::From(*output);
  bool align_flag = (align_mode == 0 && !align_corners);

  // The tables are filled through operator[], so they must really contain
  // out_w elements; reserve() alone leaves size() == 0 and indexing would be
  // undefined behaviour.
  const size_t table_len = out_w > 0 ? static_cast<size_t>(out_w) : 0;
  std::vector<int> vx_w(table_len), vx_e(table_len);
  std::vector<float> vd_w(table_len), vd_e(table_len);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int l = 0; l < out_w; l++) {
    int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;                       // left source column
    int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w;  // right source column

    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;  // w1lambda
    float d_e = 1.f - d_w;                                         // w2lambda

    vx_w[l] = x_w;
    vx_e[l] = x_e;
    vd_w[l] = d_w;
    vd_e[l] = d_e;
  }

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
  for (int i = 0; i < n; i++) {    // loop for batches
    for (int j = 0; j < c; j++) {  // loop for channels
      for (int l = 0; l < out_w; l++) {
        // linear interpolation
        T out_t;
        if (data_layout == DataLayout::kNCHW) {
          out_t = input_t(i, j, vx_w[l]) * vd_e[l] +
                  input_t(i, j, vx_e[l]) * vd_w[l];
          output_t(i, j, l) = out_t;
        } else {
          out_t = input_t(i, vx_w[l], j) * vd_e[l] +
                  input_t(i, vx_e[l], j) * vd_w[l];
          output_t(i, l, j) = out_t;
        }
      }
    }
  }
}

// Backward pass of the 1-D linear resize.
//
// For every output column the two source columns and blend weights are
// recomputed with the same formulas as the forward pass, and the incoming
// gradient is distributed onto `input_grad` with `+=` (the buffer is not
// cleared here). Tensors are indexed as (i, j, w) for DataLayout::kNCHW and
// (i, w, j) otherwise.
template <typename T>
static void LinearInterpolationGrad(const Tensor& output_grad,
                                    Tensor* input_grad, const float ratio_w,
                                    const int in_w, const int n, const int c,
                                    const int out_w, const bool align_corners,
                                    const int align_mode,
                                    const DataLayout data_layout) {
  auto dx = EigenTensor<T, 3>::From(*input_grad);
  auto dy = EigenTensor<T, 3>::From(output_grad);
  const bool half_pixel = (align_mode == 0 && !align_corners);
  const bool channel_first = (data_layout == DataLayout::kNCHW);

  for (int ow = 0; ow < out_w; ++ow) {
    // Left source column, clamped at 0.
    int left = half_pixel ? static_cast<int>(ratio_w * (ow + 0.5) - 0.5)
                          : static_cast<int>(ratio_w * ow);
    if (left < 0) {
      left = 0;
    }
    // Right source column, kept inside the input.
    const int right = (left < (in_w - 1)) ? (left + 1) : left;

    float src_pos = ratio_w * (ow + 0.5) - 0.5;
    src_pos = (src_pos > 0) ? src_pos : 0;
    // Weight of the right column; the left column gets the complement.
    const float frac = half_pixel ? src_pos - left : ratio_w * ow - left;
    const float inv_frac = 1.f - frac;

    for (int batch = 0; batch < n; ++batch) {
      for (int ch = 0; ch < c; ++ch) {
        if (channel_first) {
          const T grad = dy(batch, ch, ow);
          dx(batch, ch, left) += static_cast<T>(grad * inv_frac);
          dx(batch, ch, right) += static_cast<T>(grad * frac);
        } else {
          const T grad = dy(batch, ow, ch);
          dx(batch, left, ch) += static_cast<T>(grad * inv_frac);
          dx(batch, right, ch) += static_cast<T>(grad * frac);
        }
      }
    }
  }
}

// 2-D bilinear resize on the CPU.
//
// Step 1 builds per-row tables (vy_n / vy_s source rows, vd_s / vd_n weights)
// and per-column tables (vx_w / vx_e source columns, vd_e / vd_w weights).
// Step 2 blends the four neighbouring input elements for every batch in
// [0, n), channel in [0, c) and output position (k, l).
//
// When `align_mode == 0 && !align_corners` the source coordinate is
// ratio * (index + 0.5) - 0.5 (clamped at 0); otherwise it is ratio * index.
// Tensors are indexed as (i, j, h, w) for DataLayout::kNCHW and (i, h, w, j)
// otherwise. All loops run under OpenMP when PADDLE_WITH_MKLML is defined.
//
// NOTE(review): `align_mode` is declared `bool` here but `int` in
// BilinearInterpolationGrad; the `== 0` test gives the same result for both.
template <typename T>
static void BilinearInterpolation(const Tensor& input, Tensor* output,
                                  const float ratio_h, const float ratio_w,
                                  const int in_h, const int in_w, const int n,
                                  const int c, const int out_h, const int out_w,
                                  const bool align_corners,
                                  const bool align_mode,
                                  const DataLayout data_layout) {
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);
  bool align_flag = (align_mode == 0 && !align_corners);

  // The tables are filled through operator[], so they must really contain
  // their elements; reserve() alone leaves size() == 0 and indexing would be
  // undefined behaviour.
  const size_t rows = out_h > 0 ? static_cast<size_t>(out_h) : 0;
  const size_t cols = out_w > 0 ? static_cast<size_t>(out_w) : 0;

  std::vector<int> vy_n(rows), vy_s(rows);
  std::vector<float> vd_n(rows), vd_s(rows);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int k = 0; k < out_h; k++) {
    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * k);
    y_n = (y_n > 0) ? y_n : 0;
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
    float d_s = 1.f - d_n;

    vy_n[k] = y_n;
    vy_s[k] = y_s;
    vd_n[k] = d_n;
    vd_s[k] = d_s;
  }

  std::vector<int> vx_w(cols), vx_e(cols);
  std::vector<float> vd_w(cols), vd_e(cols);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int l = 0; l < out_w; l++) {
    int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;
    int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
    float d_e = 1.f - d_w;

    vx_w[l] = x_w;
    vx_e[l] = x_e;
    vd_w[l] = d_w;
    vd_e[l] = d_e;
  }

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(4)
#endif
  for (int i = 0; i < n; i++) {          // loop for batches
    for (int j = 0; j < c; j++) {        // loop for channels
      for (int k = 0; k < out_h; k++) {  // loop for images
        for (int l = 0; l < out_w; l++) {
          // bilinear interpolation
          T out_t;
          if (data_layout == DataLayout::kNCHW) {
            out_t = input_t(i, j, vy_n[k], vx_w[l]) * vd_s[k] * vd_e[l] +
                    input_t(i, j, vy_s[k], vx_w[l]) * vd_n[k] * vd_e[l] +
                    input_t(i, j, vy_n[k], vx_e[l]) * vd_s[k] * vd_w[l] +
                    input_t(i, j, vy_s[k], vx_e[l]) * vd_n[k] * vd_w[l];
            output_t(i, j, k, l) = out_t;

          } else {
            out_t = input_t(i, vy_n[k], vx_w[l], j) * vd_s[k] * vd_e[l] +
                    input_t(i, vy_s[k], vx_w[l], j) * vd_n[k] * vd_e[l] +
                    input_t(i, vy_n[k], vx_e[l], j) * vd_s[k] * vd_w[l] +
                    input_t(i, vy_s[k], vx_e[l], j) * vd_n[k] * vd_w[l];
            output_t(i, k, l, j) = out_t;
          }
        }
      }
    }
  }
}

// 3-D trilinear resize on the CPU.
//
// Step 1 builds lookup tables per output depth (vt_f / vt_b, weights
// vd_b / vd_f), per output row (vy_n / vy_s, weights vd_s / vd_n) and per
// output column (vx_w / vx_e, weights vd_e / vd_w).
// Step 2 blends the eight neighbouring input elements for every batch in
// [0, n), channel in [0, c) and output position (j, k, l).
//
// When `align_mode == 0 && !align_corners` the source coordinate is
// ratio * (index + 0.5) - 0.5 (clamped at 0); otherwise it is ratio * index.
// Tensors are indexed as (b, c, d, h, w) for DataLayout::kNCHW and
// (b, d, h, w, c) otherwise. All loops run under OpenMP when
// PADDLE_WITH_MKLML is defined.
//
// NOTE(review): `align_mode` is declared `bool` here but `int` in
// TrilinearInterpolationGrad; the `== 0` test gives the same result for both.
template <typename T>
static void TrilinearInterpolation(
    const Tensor& input, Tensor* output, const float ratio_d,
    const float ratio_h, const float ratio_w, const int in_d, const int in_h,
    const int in_w, const int n, const int c, const int out_d, const int out_h,
    const int out_w, const bool align_corners, const bool align_mode,
    const DataLayout& data_layout) {
  auto input_t = EigenTensor<T, 5>::From(input);
  auto output_t = EigenTensor<T, 5>::From(*output);
  bool align_flag = (align_mode == 0 && !align_corners);

  // The tables are filled through operator[], so they must really contain
  // their elements; reserve() alone leaves size() == 0 and indexing would be
  // undefined behaviour.
  const size_t depths = out_d > 0 ? static_cast<size_t>(out_d) : 0;
  const size_t rows = out_h > 0 ? static_cast<size_t>(out_h) : 0;
  const size_t cols = out_w > 0 ? static_cast<size_t>(out_w) : 0;

  std::vector<int> vt_f(depths), vt_b(depths);
  std::vector<float> vd_f(depths), vd_b(depths);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int j = 0; j < out_d; j++) {
    int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
                         : static_cast<int>(ratio_d * j);
    t_f = (t_f > 0) ? t_f : 0;
    int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
    float idx_src_t = ratio_d * (j + 0.5) - 0.5;
    idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
    float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
    float d_b = 1.f - d_f;

    vt_f[j] = t_f;
    vt_b[j] = t_b;
    vd_f[j] = d_f;
    vd_b[j] = d_b;
  }

  std::vector<int> vy_n(rows), vy_s(rows);
  std::vector<float> vd_n(rows), vd_s(rows);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int k = 0; k < out_h; k++) {
    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * k);
    y_n = (y_n > 0) ? y_n : 0;
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
    float d_s = 1.f - d_n;

    vy_n[k] = y_n;
    vy_s[k] = y_s;
    vd_n[k] = d_n;
    vd_s[k] = d_s;
  }

  std::vector<int> vx_w(cols), vx_e(cols);
  std::vector<float> vd_w(cols), vd_e(cols);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int l = 0; l < out_w; l++) {
    int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;
    int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
    float d_e = 1.f - d_w;

    vx_w[l] = x_w;
    vx_e[l] = x_e;
    vd_w[l] = d_w;
    vd_e[l] = d_e;
  }

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(5)
#endif
  for (int b = 0; b < n; b++) {          // loop for batches
    for (int i = 0; i < c; i++) {        // loop for channels
      for (int j = 0; j < out_d; j++) {  // loop for D, H, W
        for (int k = 0; k < out_h; k++) {
          for (int l = 0; l < out_w; l++) {
            // trilinear interpolation
            if (data_layout == DataLayout::kNCHW) {
              T out_t = input_t(b, i, vt_f[j], vy_n[k], vx_w[l]) * vd_b[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, i, vt_f[j], vy_n[k], vx_e[l]) * vd_b[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, i, vt_f[j], vy_s[k], vx_w[l]) * vd_b[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, i, vt_f[j], vy_s[k], vx_e[l]) * vd_b[j] *
                            vd_n[k] * vd_w[l] +
                        input_t(b, i, vt_b[j], vy_n[k], vx_w[l]) * vd_f[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, i, vt_b[j], vy_n[k], vx_e[l]) * vd_f[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, i, vt_b[j], vy_s[k], vx_w[l]) * vd_f[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, i, vt_b[j], vy_s[k], vx_e[l]) * vd_f[j] *
                            vd_n[k] * vd_w[l];
              output_t(b, i, j, k, l) = out_t;
            } else {
              T out_t = input_t(b, vt_f[j], vy_n[k], vx_w[l], i) * vd_b[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, vt_f[j], vy_n[k], vx_e[l], i) * vd_b[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, vt_f[j], vy_s[k], vx_w[l], i) * vd_b[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, vt_f[j], vy_s[k], vx_e[l], i) * vd_b[j] *
                            vd_n[k] * vd_w[l] +
                        input_t(b, vt_b[j], vy_n[k], vx_w[l], i) * vd_f[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, vt_b[j], vy_n[k], vx_e[l], i) * vd_f[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, vt_b[j], vy_s[k], vx_w[l], i) * vd_f[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, vt_b[j], vy_s[k], vx_e[l], i) * vd_f[j] *
                            vd_n[k] * vd_w[l];
              output_t(b, j, k, l, i) = out_t;
            }
          }
        }
      }
    }
  }
}

// Evaluates the cubic polynomial ((A + 2) * x - (A + 3)) * x * x + 1.
// In this header it supplies the two inner cubic interpolation weights
// (see get_cubic_upsample_coefficients).
template <typename T>
HOSTDEVICE inline T cubic_convolution1(T x, T A) {
  const T linear_part = (A + 2) * x - (A + 3);
  const T scaled = linear_part * x * x;
  return scaled + 1;
}

// Evaluates the cubic polynomial ((A * x - 5 * A) * x + 8 * A) * x - 4 * A.
// In this header it supplies the two outer cubic interpolation weights
// (see get_cubic_upsample_coefficients).
template <typename T>
HOSTDEVICE inline T cubic_convolution2(T x, T A) {
  // Horner evaluation, one step per statement.
  T acc = A * x - 5 * A;
  acc = acc * x + 8 * A;
  acc = acc * x - 4 * A;
  return acc;
}

// Fills `coeffs` with the four cubic interpolation weights for the
// fractional offset `t`, using the fixed parameter A = -0.75.
//
// coeffs[0] and coeffs[1] are evaluated at distances t + 1 and t;
// coeffs[2] and coeffs[3] are the mirrored weights, evaluated at distances
// 1 - t and (1 - t) + 1.
template <typename T>
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) {
  T A = -0.75;

  // Distances measured from the left and from the right side.
  T from_left = t;
  T from_right = 1.0 - t;

  coeffs[0] = cubic_convolution2<T>(from_left + 1.0, A);
  coeffs[1] = cubic_convolution1<T>(from_left, A);
  coeffs[2] = cubic_convolution1<T>(from_right, A);
  coeffs[3] = cubic_convolution2<T>(from_right + 1.0, A);
}

// Cubic interpolation of four consecutive samples x0..x3 at the fractional
// offset `t`: returns the samples weighted by the coefficients produced by
// get_cubic_upsample_coefficients(t).
template <typename T>
static inline T cubic_interp(T x0, T x1, T x2, T x3, T t) {
  T weights[4];
  get_cubic_upsample_coefficients<T>(weights, t);

  // Accumulate left to right so the summation order stays fixed.
  T result = x0 * weights[0];
  result = result + x1 * weights[1];
  result = result + x2 * weights[2];
  result = result + x3 * weights[3];
  return result;
}

// 2-D bicubic resize on the CPU.
//
// For each output position (k, l) the source coordinate is ratio * index
// when `align_corners` is set and ratio * (index + 0.5) - 0.5 otherwise.
// Its floor selects a 4x4 input neighbourhood (indices clamped into the
// input) and its fractional part is the interpolation offset. Each of the
// four rows is interpolated along x with cubic_interp, then the four row
// results are interpolated along y. Tensors are indexed as (i, j, h, w) for
// DataLayout::kNCHW and (i, h, w, j) otherwise.
//
// NOTE(review): the floor is taken with floorf() even when T is double, so
// the coordinate is narrowed to float first — kept as in the original.
template <typename T>
static void BicubicInterpolation(const Tensor& input, Tensor* output,
                                 const float ratio_h, const float ratio_w,
                                 const int in_h, const int in_w, const int n,
                                 const int c, const int out_h, const int out_w,
                                 const bool align_corners,
                                 const DataLayout data_layout) {
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);
  const bool channel_first = (data_layout == DataLayout::kNCHW);

  // Clamps a source index into [0, limit - 1].
  const auto clamp_index = [](const int idx, const int limit) {
    return std::max(std::min(idx, limit - 1), 0);
  };

  for (int k = 0; k < out_h; k++) {  // loop for images
    T y_n = align_corners ? static_cast<T>(ratio_h * k)
                          : static_cast<T>(ratio_h * (k + 0.5) - 0.5);
    int input_y = floorf(y_n);
    const T y_t = y_n - input_y;

    // The four source rows depend only on k: compute them once per row
    // instead of once per batch/channel.
    int access_y[4];
    for (int ii = 0; ii < 4; ii++) {
      access_y[ii] = clamp_index(input_y - 1 + ii, in_h);
    }

    for (int l = 0; l < out_w; l++) {
      T x_n = align_corners ? static_cast<T>(ratio_w * l)
                            : static_cast<T>(ratio_w * (l + 0.5) - 0.5);
      int input_x = floorf(x_n);
      const T x_t = x_n - input_x;

      // Likewise the four source columns depend only on l.
      int access_x[4];
      for (int ii = 0; ii < 4; ii++) {
        access_x[ii] = clamp_index(input_x - 1 + ii, in_w);
      }

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          T coefficients[4];
          // interp 4 times in x direction
          for (int ii = 0; ii < 4; ii++) {
            if (channel_first) {
              coefficients[ii] =
                  cubic_interp<T>(input_t(i, j, access_y[ii], access_x[0]),
                                  input_t(i, j, access_y[ii], access_x[1]),
                                  input_t(i, j, access_y[ii], access_x[2]),
                                  input_t(i, j, access_y[ii], access_x[3]),
                                  x_t);
            } else {
              coefficients[ii] =
                  cubic_interp<T>(input_t(i, access_y[ii], access_x[0], j),
                                  input_t(i, access_y[ii], access_x[1], j),
                                  input_t(i, access_y[ii], access_x[2], j),
                                  input_t(i, access_y[ii], access_x[3], j),
                                  x_t);
            }
          }

          // interp y direction
          const T value =
              cubic_interp<T>(coefficients[0], coefficients[1],
                              coefficients[2], coefficients[3], y_t);
          if (channel_first) {
            output_t(i, j, k, l) = value;
          } else {
            output_t(i, k, l, j) = value;
          }
        }
      }
    }
  }
}

// Backward pass of the 2-D nearest-neighbour resize.
//
// Every output-gradient element is added (`+=`) onto the input-gradient
// element that the forward pass read from: source index = ratio * index
// truncated to int, with 0.5 added before truncation when `align_corners`
// is set. `input_grad` is not cleared here. Tensors are indexed as
// (i, j, h, w) for DataLayout::kNCHW and (i, h, w, j) otherwise.
template <typename T>
static void NearestNeighborInterpolateGrad(
    const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
    const float ratio_w, const int n, const int c, const int out_h,
    const int out_w, const bool align_corners, const DataLayout data_layout) {
  auto dx = EigenTensor<T, 4>::From(*input_grad);
  auto dy = EigenTensor<T, 4>::From(output_grad);
  const bool channel_first = (data_layout == DataLayout::kNCHW);

  // Maps an output index to its nearest source index.
  const auto nearest = [align_corners](const float ratio, const int dst_idx) {
    return align_corners ? static_cast<int>(ratio * dst_idx + 0.5)
                         : static_cast<int>(ratio * dst_idx);
  };

  for (int oh = 0; oh < out_h; ++oh) {
    const int ih = nearest(ratio_h, oh);
    for (int ow = 0; ow < out_w; ++ow) {
      const int iw = nearest(ratio_w, ow);
      for (int batch = 0; batch < n; ++batch) {
        for (int ch = 0; ch < c; ++ch) {
          if (channel_first) {
            dx(batch, ch, ih, iw) += dy(batch, ch, oh, ow);
          } else {
            dx(batch, ih, iw, ch) += dy(batch, oh, ow, ch);
          }
        }
      }
    }
  }
}

620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655
// Backward pass of the 3-D nearest-neighbour resize.
//
// Every output-gradient element is added (`+=`) onto the input-gradient
// element that the forward pass read from: per axis, source index =
// ratio * index truncated to int, with 0.5 added before truncation when
// `align_corners` is set. `input_grad` is not cleared here. Tensors are
// indexed as (i, j, d, h, w) for DataLayout::kNCHW and (i, d, h, w, j)
// otherwise.
template <typename T>
static void NearestNeighbor3DInterpolateGrad(
    const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
    const float ratio_h, const float ratio_w, const int n, const int c,
    const int out_d, const int out_h, const int out_w, const bool align_corners,
    const DataLayout data_layout) {
  auto dx = EigenTensor<T, 5>::From(*input_grad);
  auto dy = EigenTensor<T, 5>::From(output_grad);
  const bool channel_first = (data_layout == DataLayout::kNCHW);

  // Maps an output index to its nearest source index.
  const auto nearest = [align_corners](const float ratio, const int dst_idx) {
    return align_corners ? static_cast<int>(ratio * dst_idx + 0.5)
                         : static_cast<int>(ratio * dst_idx);
  };

  for (int od = 0; od < out_d; ++od) {
    const int id = nearest(ratio_d, od);
    for (int oh = 0; oh < out_h; ++oh) {
      const int ih = nearest(ratio_h, oh);
      for (int ow = 0; ow < out_w; ++ow) {
        const int iw = nearest(ratio_w, ow);
        for (int batch = 0; batch < n; ++batch) {
          for (int ch = 0; ch < c; ++ch) {
            if (channel_first) {
              dx(batch, ch, id, ih, iw) += dy(batch, ch, od, oh, ow);
            } else {
              dx(batch, id, ih, iw, ch) += dy(batch, od, oh, ow, ch);
            }
          }
        }
      }
    }
  }
}

X
xiaoting 已提交
656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862
// Backward pass of the 2-D bilinear resize.
//
// For every output position the two source rows / columns and their blend
// weights are recomputed with the same formulas as the forward pass, and the
// incoming gradient is distributed onto the four contributing input-gradient
// elements with `+=` (`input_grad` is not cleared here). Tensors are indexed
// as (i, j, h, w) for DataLayout::kNCHW and (i, h, w, j) otherwise.
template <typename T>
static void BilinearInterpolationGrad(
    const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
    const float ratio_w, const int in_h, const int in_w, const int n,
    const int c, const int out_h, const int out_w, const bool align_corners,
    const int align_mode, const DataLayout data_layout) {
  auto dx = EigenTensor<T, 4>::From(*input_grad);
  auto dy = EigenTensor<T, 4>::From(output_grad);
  const bool half_pixel = (align_mode == 0 && !align_corners);
  const bool channel_first = (data_layout == DataLayout::kNCHW);

  for (int oh = 0; oh < out_h; ++oh) {
    // Upper / lower source rows and the weight of the lower one.
    int top = half_pixel ? static_cast<int>(ratio_h * (oh + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * oh);
    if (top < 0) {
      top = 0;
    }
    const int bottom = (top + 1) < (in_h - 1) ? (top + 1) : (in_h - 1);
    float src_y = ratio_h * (oh + 0.5) - 0.5;
    src_y = (src_y > 0) ? src_y : 0;
    const float frac_y = half_pixel ? src_y - top : ratio_h * oh - top;
    const float inv_frac_y = 1.f - frac_y;

    for (int ow = 0; ow < out_w; ++ow) {
      // Left / right source columns and the weight of the right one.
      int left = half_pixel ? static_cast<int>(ratio_w * (ow + 0.5) - 0.5)
                            : static_cast<int>(ratio_w * ow);
      if (left < 0) {
        left = 0;
      }
      const int right = (left + 1) < (in_w - 1) ? (left + 1) : (in_w - 1);
      float src_x = ratio_w * (ow + 0.5) - 0.5;
      src_x = (src_x > 0) ? src_x : 0;
      const float frac_x = half_pixel ? src_x - left : ratio_w * ow - left;
      const float inv_frac_x = 1.f - frac_x;

      for (int batch = 0; batch < n; ++batch) {
        for (int ch = 0; ch < c; ++ch) {
          // The four updates keep their order because the target elements
          // can coincide at the borders.
          if (channel_first) {
            const T grad = dy(batch, ch, oh, ow);
            dx(batch, ch, top, left) +=
                static_cast<T>(grad * inv_frac_y * inv_frac_x);
            dx(batch, ch, bottom, left) +=
                static_cast<T>(grad * frac_y * inv_frac_x);
            dx(batch, ch, top, right) +=
                static_cast<T>(grad * inv_frac_y * frac_x);
            dx(batch, ch, bottom, right) +=
                static_cast<T>(grad * frac_y * frac_x);
          } else {
            const T grad = dy(batch, oh, ow, ch);
            dx(batch, top, left, ch) +=
                static_cast<T>(grad * inv_frac_y * inv_frac_x);
            dx(batch, bottom, left, ch) +=
                static_cast<T>(grad * frac_y * inv_frac_x);
            dx(batch, top, right, ch) +=
                static_cast<T>(grad * inv_frac_y * frac_x);
            dx(batch, bottom, right, ch) +=
                static_cast<T>(grad * frac_y * frac_x);
          }
        }
      }
    }
  }
}

// Backward pass of trilinear interpolation on the CPU.
//
// Each element of `output_grad` is scattered (accumulated with `+=`) into the
// eight surrounding elements of `input_grad`; the contribution to a neighbour
// is the output gradient times the product of the three per-axis weights.
// This routine only accumulates into `input_grad`; it never clears it.
//
// ratio_d / ratio_h / ratio_w  factor turning an output index into a source
//                              coordinate along depth / height / width.
// in_d / in_h / in_w           input extents; the upper neighbour index is
//                              clamped to `extent - 1`.
// n, c                         loop bounds for batch and channel.
// out_d / out_h / out_w        output extents that are iterated over.
// align_corners, align_mode    the half-pixel mapping `ratio * (i + 0.5) - 0.5`
//                              is used only when `align_mode == 0` and
//                              `align_corners` is false; otherwise the mapping
//                              is `ratio * i`.
// data_layout                  kNCHW indexes tensors as (n, c, d, h, w); any
//                              other value indexes them as (n, d, h, w, c).
template <typename T>
static void TrilinearInterpolationGrad(
    const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
    const float ratio_h, const float ratio_w, const int in_d, const int in_h,
    const int in_w, const int n, const int c, const int out_d, const int out_h,
    const int out_w, const bool align_corners, const int align_mode,
    const DataLayout data_layout) {
  // The two source indices touched along one axis and the weight of each.
  // Slot 0 is the lower neighbour, slot 1 the (clamped) upper neighbour.
  struct AxisTap {
    int idx[2];
    float weight[2];
  };

  auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 5>::From(output_grad);

  const bool half_pixel = (align_mode == 0 && !align_corners);
  const bool channels_first = (data_layout == DataLayout::kNCHW);

  // Computes neighbours and weights for output index `dst` along one axis.
  const auto locate = [half_pixel](const float ratio, const int dst,
                                   const int in_size) -> AxisTap {
    AxisTap tap;
    int lower = half_pixel ? static_cast<int>(ratio * (dst + 0.5) - 0.5)
                           : static_cast<int>(ratio * dst);
    lower = std::max(lower, 0);
    tap.idx[0] = lower;
    tap.idx[1] = std::min(lower + 1, in_size - 1);

    // Half-pixel source coordinate, forced to 0 unless strictly positive.
    float src = ratio * (dst + 0.5) - 0.5;
    if (!(src > 0)) {
      src = 0;
    }
    const float frac = half_pixel ? src - lower : ratio * dst - lower;
    tap.weight[0] = 1.f - frac;  // share of the lower neighbour
    tap.weight[1] = frac;        // share of the upper neighbour
    return tap;
  };

  for (int od = 0; od < out_d; ++od) {  // output depth
    const AxisTap tap_d = locate(ratio_d, od, in_d);

    for (int oh = 0; oh < out_h; ++oh) {  // output height
      const AxisTap tap_h = locate(ratio_h, oh, in_h);

      for (int ow = 0; ow < out_w; ++ow) {  // output width
        const AxisTap tap_w = locate(ratio_w, ow, in_w);

        for (int batch = 0; batch < n; ++batch) {
          for (int ch = 0; ch < c; ++ch) {
            const T grad = channels_first
                               ? output_grad_t(batch, ch, od, oh, ow)
                               : output_grad_t(batch, od, oh, ow, ch);

            // Visit the 2x2x2 neighbourhood with depth outermost and width
            // innermost so coinciding (clamped) neighbours are accumulated
            // in a fixed order.
            for (int sd = 0; sd < 2; ++sd) {
              for (int sh = 0; sh < 2; ++sh) {
                for (int sw = 0; sw < 2; ++sw) {
                  const T delta =
                      static_cast<T>(grad * tap_d.weight[sd] *
                                     tap_h.weight[sh] * tap_w.weight[sw]);
                  if (channels_first) {
                    input_grad_t(batch, ch, tap_d.idx[sd], tap_h.idx[sh],
                                 tap_w.idx[sw]) += delta;
                  } else {
                    input_grad_t(batch, tap_d.idx[sd], tap_h.idx[sh],
                                 tap_w.idx[sw], ch) += delta;
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}

// Backward pass of bicubic interpolation on the CPU.
//
// For every output position the source coordinate is split into an integer
// base (via floorf) and a fractional part; the fractional parts are handed to
// get_cubic_upsample_coefficients to obtain four weights per axis. The output
// gradient is then accumulated (`+=`) into the 4x4 block of `input_grad`
// starting one element before the base, with indices clamped to
// [0, in_h - 1] x [0, in_w - 1]. `input_grad` is only accumulated into here.
//
// ratio_h / ratio_w  factor turning an output index into a source coordinate.
// align_corners      true: coordinate is `ratio * i`;
//                    false: coordinate is `ratio * (i + 0.5) - 0.5`.
// data_layout        kNCHW indexes tensors as (n, c, h, w); any other value
//                    indexes them as (n, h, w, c).
template <typename T>
static void BicubicInterpolationGrad(const Tensor& output_grad,
                                     Tensor* input_grad, const float ratio_h,
                                     const float ratio_w, const int in_h,
                                     const int in_w, const int n, const int c,
                                     const int out_h, const int out_w,
                                     const bool align_corners,
                                     const DataLayout data_layout) {
  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
  const bool channels_first = (data_layout == DataLayout::kNCHW);

  for (int oh = 0; oh < out_h; ++oh) {  // output rows
    T src_y;
    if (align_corners) {
      src_y = static_cast<T>(ratio_h * oh);
    } else {
      src_y = static_cast<T>(ratio_h * (oh + 0.5) - 0.5);
    }
    const int base_y = static_cast<int>(floorf(src_y));
    T frac_y = src_y - base_y;

    for (int ow = 0; ow < out_w; ++ow) {  // output columns
      T src_x;
      if (align_corners) {
        src_x = static_cast<T>(ratio_w * ow);
      } else {
        src_x = static_cast<T>(ratio_w * (ow + 0.5) - 0.5);
      }
      const int base_x = static_cast<int>(floorf(src_x));
      T frac_x = src_x - base_x;

      T weights_x[4];
      T weights_y[4];
      get_cubic_upsample_coefficients<T>(weights_x, frac_x);
      get_cubic_upsample_coefficients<T>(weights_y, frac_y);

      for (int batch = 0; batch < n; ++batch) {
        for (int ch = 0; ch < c; ++ch) {
          const T grad = channels_first ? output_grad_t(batch, ch, oh, ow)
                                        : output_grad_t(batch, oh, ow, ch);

          // Scatter into the clamped 4x4 footprint: columns in the outer
          // loop, rows in the inner loop.
          for (int tx = 0; tx < 4; ++tx) {
            const int col =
                std::max(std::min(base_x - 1 + tx, in_w - 1), 0);
            for (int ty = 0; ty < 4; ++ty) {
              const int row =
                  std::max(std::min(base_y - 1 + ty, in_h - 1), 0);
              if (channels_first) {
                input_grad_t(batch, ch, row, col) +=
                    grad * weights_y[ty] * weights_x[tx];
              } else {
                input_grad_t(batch, row, col, ch) +=
                    grad * weights_y[ty] * weights_x[tx];
              }
            }
          }
        }
      }
    }
  }
}

// CPU forward driver for 1-D interpolation.
//
// Resolves the output width, allocates `output` with that width, and runs
// LinearInterpolation when the "interp_method" attribute is "linear".
//
// Output width resolution:
//   * If the "SizeTensor" input list is non-empty, its first entry is used.
//   * Otherwise the width starts from the "out_w" attribute; a positive scale
//     (the "Scale" input tensor if present, else the first entry of the
//     "scale" attribute) replaces it with `in_w * scale`; finally the
//     "OutSize" input tensor, if present, overrides the result.
//
// Raises InvalidArgument (through PADDLE_ENFORCE_*) when a supplied scale is
// not greater than 0 or when the resolved width is not greater than 0.
// When the input and output widths are equal the input is copied unchanged.
template <typename T>
static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
                                const Tensor& input, Tensor* output) {
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
  const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_w = ctx.Attr<int>("out_w");
  float scale_w = -1.;  // stays negative when no scale is supplied

  auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    out_w = new_size[0];
  } else {
    auto scale_tensor = ctx.Input<Tensor>("Scale");
    auto scale = ctx.Attr<std::vector<float>>("scale");
    if (scale_tensor != nullptr) {
      auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
      scale_w = scale_data[0];
      PADDLE_ENFORCE_EQ(
          scale_w > 0, true,
          platform::errors::InvalidArgument(
              "The scale_w in input 'Scale' Tensor of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_w));
    } else if (scale.size() > 0) {
      scale_w = scale[0];
      PADDLE_ENFORCE_EQ(
          scale_w > 0, true,
          platform::errors::InvalidArgument(
              "The scale_w in Attr(scale) of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_w));
    }
    if (scale_w > 0.) {
      out_w = static_cast<int>(in_w * scale_w);
    }
    // An explicit OutSize tensor takes precedence over the scaled width.
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      auto out_size_data = get_new_data_from_tensor<int>(out_size);
      out_w = out_size_data[0];
    }
  }
  PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
                                  "out_w in Attr(out_shape) of Op(interpolate) "
                                  "should be greater than 0."));

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {n, c, out_w};
  } else {
    dim_out = {n, out_w, c};
  }
  output->mutable_data<T>(dim_out, ctx.GetPlace());

  // Same size: nothing to interpolate.
  if (in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Output-index -> source-coordinate factor; stays 0 for a single output
  // element. Without align_corners a user-supplied scale is preferred over
  // the in/out size quotient.
  float ratio_w = 0.f;
  if (out_w > 1) {
    const float new_scale_w = (scale_w > 0)
                                  ? static_cast<float>(1. / scale_w)
                                  : static_cast<float>(in_w) / out_w;
    ratio_w = align_corners ? static_cast<float>(in_w - 1) / (out_w - 1)
                            : new_scale_w;
  }

  if ("linear" == interp_method) {
    LinearInterpolation<T>(input, output, ratio_w, in_w, n, c, out_w,
                           align_corners, align_mode, data_layout);
  }
}

// CPU forward driver for 2-D interpolation.
//
// Resolves the output height/width, allocates `output`, and dispatches on the
// "interp_method" attribute to BilinearInterpolation ("bilinear"),
// NearestNeighborInterpolate ("nearest") or BicubicInterpolation ("bicubic").
//
// Output size resolution:
//   * If the "SizeTensor" input list is non-empty, its first two entries are
//     used as (out_h, out_w).
//   * Otherwise the size starts from the "out_h"/"out_w" attributes. Scales
//     come from the "Scale" input tensor if present (two values, or a single
//     value applied to both axes), else from the "scale" attribute when it
//     has more than one entry. If both scales are positive the size becomes
//     `in * scale`. Finally the "OutSize" input tensor, if present, overrides
//     the result.
//
// Raises InvalidArgument (through PADDLE_ENFORCE_*) when a supplied scale is
// not greater than 0 or when a resolved extent is not greater than 0.
// When input and output sizes are equal the input is copied unchanged.
template <typename T>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
                                const Tensor& input, Tensor* output) {
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
  const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  // Scales stay negative when none is supplied.
  float scale_h = -1;
  float scale_w = -1;

  auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    out_h = new_size[0];
    out_w = new_size[1];
  } else {
    auto scale_tensor = ctx.Input<Tensor>("Scale");
    auto scale = ctx.Attr<std::vector<float>>("scale");
    if (scale_tensor != nullptr) {
      auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
      if (scale_data.size() > 1) {
        scale_h = scale_data[0];
        scale_w = scale_data[1];
      } else {
        // A single value applies to both axes.
        scale_h = scale_data[0];
        scale_w = scale_data[0];
      }
      PADDLE_ENFORCE_EQ(
          scale_w > 0, true,
          platform::errors::InvalidArgument(
              "The scale_w in input 'Scale' Tensor of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_w));
      PADDLE_ENFORCE_EQ(
          scale_h > 0, true,
          platform::errors::InvalidArgument(
              "The scale_h in input 'Scale' Tensor of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_h));
    } else if (scale.size() > 1) {
      scale_h = scale[0];
      scale_w = scale[1];

      PADDLE_ENFORCE_EQ(
          scale_w > 0, true,
          platform::errors::InvalidArgument(
              "The scale_w in Attr(scale) of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_w));
      PADDLE_ENFORCE_EQ(
          scale_h > 0, true,
          platform::errors::InvalidArgument(
              "The scale_h in Attr(scale) of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_h));
    }
    if (scale_h > 0. && scale_w > 0.) {
      out_h = static_cast<int>(in_h * scale_h);
      out_w = static_cast<int>(in_w * scale_w);
    }
    // An explicit OutSize tensor takes precedence over the scaled size.
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      auto out_size_data = get_new_data_from_tensor<int>(out_size);
      out_h = out_size_data[0];
      out_w = out_size_data[1];
    }
  }
  PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
                                  "out_h in Attr(out_shape) of Op(interpolate) "
                                  "should be greater than 0."));
  PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
                                  "out_w in Attr(out_shape) of Op(interpolate) "
                                  "should be greater than 0."));

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {n, c, out_h, out_w};
  } else {
    dim_out = {n, out_h, out_w, c};
  }
  output->mutable_data<T>(dim_out, ctx.GetPlace());

  // Same size: nothing to interpolate.
  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Output-index -> source-coordinate factor for one axis; 0 for a single
  // output element. Without align_corners a user-supplied scale is preferred
  // over the in/out size quotient.
  const auto axis_ratio = [align_corners](const float scale, const int in_size,
                                          const int out_size) -> float {
    if (out_size <= 1) {
      return 0.f;
    }
    if (align_corners) {
      return static_cast<float>(in_size - 1) / (out_size - 1);
    }
    return (scale > 0) ? static_cast<float>(1. / scale)
                       : static_cast<float>(in_size) / out_size;
  };
  const float ratio_h = axis_ratio(scale_h, in_h, out_h);
  const float ratio_w = axis_ratio(scale_w, in_w, out_w);

  if ("bilinear" == interp_method) {
    BilinearInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
                             out_h, out_w, align_corners, align_mode,
                             data_layout);
  } else if ("nearest" == interp_method) {
    NearestNeighborInterpolate<T>(input, output, ratio_h, ratio_w, n, c, out_h,
                                  out_w, align_corners, data_layout);
  } else if ("bicubic" == interp_method) {
    BicubicInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
                            out_h, out_w, align_corners, data_layout);
  }
}

// CPU forward driver for 3-D interpolation.
//
// Resolves the output depth/height/width, allocates `output`, and dispatches
// on the "interp_method" attribute to TrilinearInterpolation ("trilinear") or
// NearestNeighbor3DInterpolate ("nearest").
//
// Output size resolution:
//   * If the "SizeTensor" input list is non-empty, its first three entries
//     are used as (out_d, out_h, out_w).
//   * Otherwise the size starts from the "out_d"/"out_h"/"out_w" attributes.
//     Scales come from the "Scale" input tensor if present (three values, or
//     else its first value applied to all axes), else from the "scale"
//     attribute when it has at least three entries. If all three scales are
//     positive the size becomes `in * scale`. Finally the "OutSize" input
//     tensor, if present, overrides the result.
//
// Raises InvalidArgument (through PADDLE_ENFORCE_*) when a supplied scale is
// not greater than 0 or when a resolved extent is not greater than 0.
// When input and output sizes are equal the input is copied unchanged.
template <typename T>
static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
                                const Tensor& input, Tensor* output) {
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
  const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");

  // Scales stay negative when none is supplied.
  float scale_d = -1;
  float scale_h = -1;
  float scale_w = -1;

  auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    out_d = new_size[0];
    out_h = new_size[1];
    out_w = new_size[2];
  } else {
    auto scale_tensor = ctx.Input<Tensor>("Scale");
    auto scale = ctx.Attr<std::vector<float>>("scale");
    if (scale_tensor != nullptr) {
      auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
      // Three entries are read below, so require at least three; a shorter
      // tensor falls back to using its first value for every axis.
      if (scale_data.size() > 2) {
        scale_d = scale_data[0];
        scale_h = scale_data[1];
        scale_w = scale_data[2];
      } else {
        scale_d = scale_data[0];
        scale_h = scale_data[0];
        scale_w = scale_data[0];
      }
      PADDLE_ENFORCE_EQ(
          scale_w > 0, true,
          platform::errors::InvalidArgument(
              "The scale_w in input 'Scale' Tensor of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_w));
      PADDLE_ENFORCE_EQ(
          scale_h > 0, true,
          platform::errors::InvalidArgument(
              "The scale_h in input 'Scale' Tensor of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_h));
      PADDLE_ENFORCE_EQ(
          scale_d > 0, true,
          platform::errors::InvalidArgument(
              "The scale_d in input 'Scale' Tensor of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_d));
    } else if (scale.size() > 2) {
      // Index 2 is read, so the attribute must hold at least three values.
      scale_d = scale[0];
      scale_h = scale[1];
      scale_w = scale[2];

      PADDLE_ENFORCE_EQ(
          scale_w > 0, true,
          platform::errors::InvalidArgument(
              "The scale_w in Attr(scale) of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_w));
      PADDLE_ENFORCE_EQ(
          scale_h > 0, true,
          platform::errors::InvalidArgument(
              "The scale_h in Attr(scale) of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_h));
      PADDLE_ENFORCE_EQ(
          scale_d > 0, true,
          platform::errors::InvalidArgument(
              "The scale_d in Attr(scale) of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_d));
    }
    if (scale_w > 0. && scale_h > 0. && scale_d > 0.) {
      out_d = static_cast<int>(in_d * scale_d);
      out_h = static_cast<int>(in_h * scale_h);
      out_w = static_cast<int>(in_w * scale_w);
    }
    // An explicit OutSize tensor takes precedence over the scaled size.
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      auto out_size_data = get_new_data_from_tensor<int>(out_size);
      out_d = out_size_data[0];
      out_h = out_size_data[1];
      out_w = out_size_data[2];
    }
  }
  PADDLE_ENFORCE_GT(out_d, 0, platform::errors::InvalidArgument(
                                  "out_d in Attr(out_shape) of Op(interpolate) "
                                  "should be greater than 0."));
  PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
                                  "out_h in Attr(out_shape) of Op(interpolate) "
                                  "should be greater than 0."));
  PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
                                  "out_w in Attr(out_shape) of Op(interpolate) "
                                  "should be greater than 0."));

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {n, c, out_d, out_h, out_w};
  } else {
    dim_out = {n, out_d, out_h, out_w, c};
  }

  output->mutable_data<T>(dim_out, ctx.GetPlace());

  // Same size: nothing to interpolate.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Output-index -> source-coordinate factor for one axis; 0 for a single
  // output element. Without align_corners a user-supplied scale is preferred
  // over the in/out size quotient.
  const auto axis_ratio = [align_corners](const float scale, const int in_size,
                                          const int out_size) -> float {
    if (out_size <= 1) {
      return 0.f;
    }
    if (align_corners) {
      return static_cast<float>(in_size - 1) / (out_size - 1);
    }
    return (scale > 0) ? static_cast<float>(1. / scale)
                       : static_cast<float>(in_size) / out_size;
  };
  const float ratio_d = axis_ratio(scale_d, in_d, out_d);
  const float ratio_h = axis_ratio(scale_h, in_h, out_h);
  const float ratio_w = axis_ratio(scale_w, in_w, out_w);

  if ("trilinear" == interp_method) {
    TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d,
                              in_h, in_w, n, c, out_d, out_h, out_w,
                              align_corners, align_mode, data_layout);
  } else if ("nearest" == interp_method) {
    NearestNeighbor3DInterpolate<T>(input, output, ratio_d, ratio_h, ratio_w, n,
                                    c, out_d, out_h, out_w, align_corners,
                                    data_layout);
  }
}

// CPU backward driver for 1-D interpolation.
//
// Recomputes the forward output width, allocates `input_grad` with the shape
// of input "X", and runs LinearInterpolationGrad when the "interp_method"
// attribute is "linear".
//
// Output width resolution (later sources override earlier ones): the "out_w"
// attribute; `in_w * scale` when a positive scale is supplied (the "Scale"
// input tensor if present, else the first entry of the "scale" attribute);
// the "OutSize" input tensor; the first entry of the "SizeTensor" input list.
//
// Raises InvalidArgument (through PADDLE_ENFORCE_EQ) when a supplied scale is
// not greater than 0. When input and output widths are equal, `output_grad`
// is copied to `input_grad` unchanged.
template <typename T>
static void Interpolate1DCPUBwd(const framework::ExecutionContext& ctx,
                                Tensor* input_grad, const Tensor& output_grad) {
  auto* input = ctx.Input<Tensor>("X");
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
  const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_w = ctx.Attr<int>("out_w");
  float scale_w = -1.0;  // stays negative when no scale is supplied
  auto scale_tensor = ctx.Input<Tensor>("Scale");
  auto scale = ctx.Attr<std::vector<float>>("scale");
  if (scale_tensor != nullptr) {
    auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
    scale_w = scale_data[0];
    PADDLE_ENFORCE_EQ(
        scale_w > 0, true,
        platform::errors::InvalidArgument(
            "The scale_w in input 'Scale' Tensor of Operator(interpolate) "
            "should be greater than 0, but received value is %d.",
            scale_w));
  } else if (scale.size() > 0) {
    scale_w = scale[0];
    PADDLE_ENFORCE_EQ(
        scale_w > 0, true,
        platform::errors::InvalidArgument(
            "The scale_w in Attr(scale) of Operator(interpolate) "
            "should be greater than 0, but received value is %d.",
            scale_w));
  }
  if (scale_w > 0.) {
    out_w = static_cast<int>(in_w * scale_w);
  }
  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    auto out_size_data = get_new_data_from_tensor<int>(out_size);
    out_w = out_size_data[0];
  }
  auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    out_w = new_size[0];
  }

  framework::DDim dim_grad;
  if (data_layout == DataLayout::kNCHW) {
    dim_grad = {n, c, in_w};
  } else {
    dim_grad = {n, in_w, c};
  }
  input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());

  // Same size: the gradient passes through unchanged. The copy overwrites the
  // whole buffer, so no zero-fill is needed on this path.
  if (in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // The grad kernel accumulates with +=, so start from zeros.
  auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
  math::SetConstant<platform::CPUDeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  // Output-index -> source-coordinate factor; stays 0 for a single output
  // element. Without align_corners a user-supplied scale is preferred over
  // the in/out size quotient.
  float ratio_w = 0.f;
  if (out_w > 1) {
    const float new_scale_w = (scale_w > 0)
                                  ? static_cast<float>(1. / scale_w)
                                  : static_cast<float>(in_w) / out_w;
    ratio_w = align_corners ? static_cast<float>(in_w - 1) / (out_w - 1)
                            : new_scale_w;
  }

  if ("linear" == interp_method) {
    LinearInterpolationGrad<T>(output_grad, input_grad, ratio_w, in_w, n, c,
                               out_w, align_corners, align_mode, data_layout);
  }
}

// CPU backward driver for 2-D interpolation.
//
// Recomputes the forward output height/width, allocates `input_grad` with the
// shape of input "X", and dispatches on the "interp_method" attribute to
// BilinearInterpolationGrad ("bilinear"), NearestNeighborInterpolateGrad
// ("nearest") or BicubicInterpolationGrad ("bicubic").
//
// Output size resolution (later sources override earlier ones): the
// "out_h"/"out_w" attributes; `in * scale` when both scales are positive
// (from the "Scale" input tensor if present - two values, or one value used
// for both axes - else from the "scale" attribute when it has more than one
// entry); the "OutSize" input tensor; the first two entries of the
// "SizeTensor" input list.
//
// Raises InvalidArgument (through PADDLE_ENFORCE_EQ) when a supplied scale is
// not greater than 0. When input and output sizes are equal, `output_grad`
// is copied to `input_grad` unchanged.
template <typename T>
static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
                                Tensor* input_grad, const Tensor& output_grad) {
  auto* input = ctx.Input<Tensor>("X");
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
  const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  // Scales stay negative when none is supplied.
  float scale_h = -1;
  float scale_w = -1;
  auto scale_tensor = ctx.Input<Tensor>("Scale");
  auto scale = ctx.Attr<std::vector<float>>("scale");
  if (scale_tensor != nullptr) {
    auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
    if (scale_data.size() > 1) {
      scale_h = scale_data[0];
      scale_w = scale_data[1];
    } else {
      // A single value applies to both axes.
      scale_w = scale_data[0];
      scale_h = scale_data[0];
    }
    PADDLE_ENFORCE_EQ(
        scale_w > 0, true,
        platform::errors::InvalidArgument(
            "The scale_w in input 'Scale' Tensor of Operator(interpolate) "
            "should be greater than 0, but received value is %d.",
            scale_w));
    PADDLE_ENFORCE_EQ(
        scale_h > 0, true,
        platform::errors::InvalidArgument(
            "The scale_h in input 'Scale' Tensor of Operator(interpolate) "
            "should be greater than 0, but received value is %d.",
            scale_h));
  } else if (scale.size() > 1) {
    scale_h = scale[0];
    scale_w = scale[1];
    PADDLE_ENFORCE_EQ(
        scale_w > 0, true,
        platform::errors::InvalidArgument(
            "The scale_w in Attr(scale) of Operator(interpolate) "
            "should be greater than 0, but received value is %d.",
            scale_w));
    PADDLE_ENFORCE_EQ(
        scale_h > 0, true,
        platform::errors::InvalidArgument(
            "The scale_h in Attr(scale) of Operator(interpolate) "
            "should be greater than 0, but received value is %d.",
            scale_h));
  }
  if (scale_h > 0. && scale_w > 0.) {
    out_h = static_cast<int>(in_h * scale_h);
    out_w = static_cast<int>(in_w * scale_w);
  }
  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    auto out_size_data = get_new_data_from_tensor<int>(out_size);
    out_h = out_size_data[0];
    out_w = out_size_data[1];
  }
  auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    out_h = new_size[0];
    out_w = new_size[1];
  }

  framework::DDim dim_grad;
  if (data_layout == DataLayout::kNCHW) {
    dim_grad = {n, c, in_h, in_w};
  } else {
    dim_grad = {n, in_h, in_w, c};
  }
  input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());

  // Same size: the gradient passes through unchanged. The copy overwrites the
  // whole buffer, so no zero-fill is needed on this path.
  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // The grad kernels accumulate into input_grad, so start from zeros.
  auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
  math::SetConstant<platform::CPUDeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  // Output-index -> source-coordinate factor for one axis; 0 for a single
  // output element. Without align_corners a user-supplied scale is preferred
  // over the in/out size quotient.
  const auto axis_ratio = [align_corners](const float scale, const int in_size,
                                          const int out_size) -> float {
    if (out_size <= 1) {
      return 0.f;
    }
    if (align_corners) {
      return static_cast<float>(in_size - 1) / (out_size - 1);
    }
    return (scale > 0) ? static_cast<float>(1. / scale)
                       : static_cast<float>(in_size) / out_size;
  };
  const float ratio_h = axis_ratio(scale_h, in_h, out_h);
  const float ratio_w = axis_ratio(scale_w, in_w, out_w);

  if ("bilinear" == interp_method) {
    BilinearInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
                                 in_h, in_w, n, c, out_h, out_w, align_corners,
                                 align_mode, data_layout);
  } else if ("nearest" == interp_method) {
    NearestNeighborInterpolateGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
                                      n, c, out_h, out_w, align_corners,
                                      data_layout);
  } else if ("bicubic" == interp_method) {
    BicubicInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w, in_h,
                                in_w, n, c, out_h, out_w, align_corners,
                                data_layout);
  }
}

// CPU backward pass of 3-D interpolation (5-D tensors).
//
// The forward output size (out_d/out_h/out_w) is resolved from the following
// sources, each later one overriding the earlier ones when present:
//   1. attributes "out_d" / "out_h" / "out_w";
//   2. the scale factors (input tensor "Scale" if given, otherwise
//      attribute "scale"), applied as static_cast<int>(in_size * scale);
//   3. input tensor "OutSize";
//   4. input tensor list "SizeTensor".
//
// `input_grad` is allocated with the shape of input "X" and zero-filled. If
// the resolved output size equals the input size, `output_grad` is copied
// into it; otherwise the gradient routine selected by attribute
// "interp_method" ("trilinear" or "nearest") is called. Any other method name
// leaves `input_grad` zero-filled.
//
// ctx:         execution context providing the inputs and attributes above.
// input_grad:  output, gradient w.r.t. "X"; (re)allocated by this function.
// output_grad: gradient w.r.t. the forward output.
//
// Raises InvalidArgument when a scale value is not greater than 0, or when a
// scale / size source holds too few values for the three spatial dimensions.
template <typename T>
static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
                                Tensor* input_grad, const Tensor output_grad) {
  auto* input = ctx.Input<Tensor>("X");
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
  const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  const auto& interp_method = ctx.Attr<std::string>("interp_method");
  const bool align_corners = ctx.Attr<bool>("align_corners");
  const int align_mode = ctx.Attr<int>("align_mode");

  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");

  // A negative scale means "not specified" for that dimension.
  float scale_d = -1;
  float scale_h = -1;
  float scale_w = -1;
  auto scale_tensor = ctx.Input<Tensor>("Scale");
  const auto& scale = ctx.Attr<std::vector<float>>("scale");
  if (scale_tensor != nullptr) {
    auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
    // Either one value shared by all dimensions, or one value per dimension.
    // Anything else would be read out of bounds below.
    PADDLE_ENFORCE_EQ(
        scale_data.size() == 1 || scale_data.size() >= 3, true,
        platform::errors::InvalidArgument(
            "The input 'Scale' Tensor of Operator(interpolate) should hold "
            "1 or at least 3 values for 3-D interpolation, but received %d "
            "values.",
            scale_data.size()));
    if (scale_data.size() > 1) {
      scale_d = scale_data[0];
      scale_h = scale_data[1];
      scale_w = scale_data[2];
    } else {
      scale_d = scale_data[0];
      scale_h = scale_data[0];
      scale_w = scale_data[0];
    }
    PADDLE_ENFORCE_EQ(
        scale_w > 0, true,
        platform::errors::InvalidArgument(
            "The scale_w in input 'Scale' Tensor of Operator(interpolate) "
            "should be greater than 0, but received value is %d.",
            scale_w));
    PADDLE_ENFORCE_EQ(
        scale_h > 0, true,
        platform::errors::InvalidArgument(
            "The scale_h in input 'Scale' Tensor of Operator(interpolate) "
            "should be greater than 0, but received value is %d.",
            scale_h));
    PADDLE_ENFORCE_EQ(
        scale_d > 0, true,
        platform::errors::InvalidArgument(
            "The scale_d in input 'Scale' Tensor of Operator(interpolate) "
            "should be greater than 0, but received value is %d.",
            scale_d));
  } else {
    if (scale.size() > 1) {
      // Three entries are read below; a 2-element attribute would be read
      // out of bounds.
      PADDLE_ENFORCE_EQ(
          scale.size() >= 3, true,
          platform::errors::InvalidArgument(
              "The Attr(scale) of Operator(interpolate) should hold at least "
              "3 values for 3-D interpolation, but received %d values.",
              scale.size()));
      scale_d = scale[0];
      scale_h = scale[1];
      scale_w = scale[2];
      PADDLE_ENFORCE_EQ(
          scale_w > 0, true,
          platform::errors::InvalidArgument(
              "The scale_w in Attr(scale) of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_w));
      PADDLE_ENFORCE_EQ(
          scale_h > 0, true,
          platform::errors::InvalidArgument(
              "The scale_h in Attr(scale) of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_h));
      PADDLE_ENFORCE_EQ(
          scale_d > 0, true,
          platform::errors::InvalidArgument(
              "The scale_d in Attr(scale) of Operator(interpolate) "
              "should be greater than 0, but received value is %d.",
              scale_d));
    }
  }
  if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
    out_d = static_cast<int>(in_d * scale_d);
    out_h = static_cast<int>(in_h * scale_h);
    out_w = static_cast<int>(in_w * scale_w);
  }

  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    auto out_size_data = get_new_data_from_tensor<int>(out_size);
    PADDLE_ENFORCE_EQ(
        out_size_data.size() >= 3, true,
        platform::errors::InvalidArgument(
            "The input 'OutSize' Tensor of Operator(interpolate) should hold "
            "at least 3 values for 3-D interpolation, but received %d values.",
            out_size_data.size()));
    out_d = out_size_data[0];
    out_h = out_size_data[1];
    out_w = out_size_data[2];
  }

  auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    PADDLE_ENFORCE_EQ(
        new_size.size() >= 3, true,
        platform::errors::InvalidArgument(
            "The input 'SizeTensor' of Operator(interpolate) should hold at "
            "least 3 tensors for 3-D interpolation, but received %d.",
            new_size.size()));
    out_d = new_size[0];
    out_h = new_size[1];
    out_w = new_size[2];
  }

  framework::DDim dim_grad;
  if (data_layout == DataLayout::kNCHW) {
    dim_grad = {n, c, in_d, in_h, in_w};
  } else {
    dim_grad = {n, in_d, in_h, in_w, c};
  }
  input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
  auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
  math::SetConstant<platform::CPUDeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  // Same spatial size: the gradient is passed through unchanged.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // Ratio between input and output extent along one dimension. It stays 0
  // when the output extent is <= 1. With align_corners it is
  // (in - 1) / (out - 1); otherwise it is 1 / scale when a positive scale
  // was given, else in / out.
  const auto compute_ratio = [align_corners](int in_size, int out_size,
                                             float scale_value) -> float {
    if (out_size <= 1) {
      return 0.f;
    }
    if (align_corners) {
      return static_cast<float>(in_size - 1) / (out_size - 1);
    }
    return (scale_value > 0) ? static_cast<float>(1. / scale_value)
                             : static_cast<float>(in_size) / out_size;
  };
  const float ratio_d = compute_ratio(in_d, out_d, scale_d);
  const float ratio_h = compute_ratio(in_h, out_h, scale_h);
  const float ratio_w = compute_ratio(in_w, out_w, scale_w);

  if ("trilinear" == interp_method) {
    TrilinearInterpolationGrad<T>(
        output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n,
        c, out_d, out_h, out_w, align_corners, align_mode, data_layout);
  } else if ("nearest" == interp_method) {
    NearestNeighbor3DInterpolateGrad<T>(output_grad, input_grad, ratio_d,
                                        ratio_h, ratio_w, n, c, out_d, out_h,
                                        out_w, align_corners, data_layout);
  }
}

// CPU forward kernel of the interpolate_v2 operator.
//
// Reads input "X" and output "Out" from the execution context and picks the
// forward routine by the rank of "X":
//   rank 3 -> Interpolate1DCPUFwd
//   rank 4 -> Interpolate2DCPUFwd
//   rank 5 -> Interpolate3DCPUFwd
// For any other rank no routine is called.
template <typename T>
class InterpolateV2Kernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* x = ctx.Input<Tensor>("X");
    Tensor* out = ctx.Output<Tensor>("Out");

    switch (x->dims().size()) {
      case 3:  // 1D interpolation
        Interpolate1DCPUFwd<T>(ctx, *x, out);
        break;
      case 4:  // 2D interpolation
        Interpolate2DCPUFwd<T>(ctx, *x, out);
        break;
      case 5:  // 3D interpolation
        Interpolate3DCPUFwd<T>(ctx, *x, out);
        break;
      default:
        break;
    }
  }
};

// CPU backward kernel of the interpolate_v2 operator.
//
// Reads the gradient of "Out" and the gradient slot of "X" from the execution
// context and picks the backward routine by the rank of the output gradient:
//   rank 3 -> Interpolate1DCPUBwd
//   rank 4 -> Interpolate2DCPUBwd
//   rank 5 -> Interpolate3DCPUBwd
// For any other rank no routine is called.
template <typename T>
class InterpolateV2GradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    switch (dout->dims().size()) {
      case 3:  // 1D interpolation grad
        Interpolate1DCPUBwd<T>(ctx, dx, *dout);
        break;
      case 4:  // 2D interpolation grad
        Interpolate2DCPUBwd<T>(ctx, dx, *dout);
        break;
      case 5:  // 3D interpolation grad
        Interpolate3DCPUBwd<T>(ctx, dx, *dout);
        break;
      default:
        break;
    }
  }
};

}  // namespace operators
}  // namespace paddle