interpolate_op.h 40.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
X
xiaoting 已提交
13
#include <algorithm>
14
#include <string>
15
#include <vector>
16 17
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
X
xiaoting 已提交
18
#include "paddle/fluid/platform/hostdevice.h"
19 20 21 22 23 24 25 26

namespace paddle {
namespace operators {

// Short aliases for the framework types used throughout this header.
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;

// Eigen tensor view of rank D over a framework tensor; row-major by default.
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
// Collects one int32 value from each tensor in `list_new_shape_tensor`, in
// order, and returns them as a shape vector.
//
// Every tensor must have dims == [1] (enforced). Tensors that live on a GPU
// place are first copied synchronously to the CPU before being read.
// NOTE(review): the element type is read as int32_t; assumes the size
// tensors are int32 — confirm against the op definition.
inline std::vector<int> get_new_shape(
    const std::vector<const Tensor*>& list_new_shape_tensor) {
  std::vector<int> vec_new_shape;
  for (const Tensor* dim_tensor : list_new_shape_tensor) {
    PADDLE_ENFORCE_EQ(dim_tensor->dims(), framework::make_ddim({1}),
                      "shape of dim tensor should be [1]");
    int32_t dim_value;
    if (!platform::is_gpu_place(dim_tensor->place())) {
      dim_value = *dim_tensor->data<int32_t>();
    } else {
      // Device memory cannot be dereferenced on the host; stage a copy.
      framework::Tensor host_copy;
      TensorCopySync(*dim_tensor, platform::CPUPlace(), &host_copy);
      dim_value = *host_copy.data<int32_t>();
    }
    vec_new_shape.push_back(static_cast<int32_t>(dim_value));
  }
  return vec_new_shape;
}

template <typename T>
inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
  std::vector<T> vec_new_data;
  auto* new_data = new_data_tensor->data<T>();
  framework::Tensor cpu_starts_tensor;
  if (platform::is_gpu_place(new_data_tensor->place())) {
    TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
    new_data = cpu_starts_tensor.data<T>();
  }
  vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
  return vec_new_data;
}

63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
// Splits `dims` into batch (N), channel (C), depth (D), height (H) and
// width (W) sizes.
//
// A rank-4 `dims` is read as NCHW (kNCHW) or NHWC (any other layout) and
// reports D = 1. Any other rank is read as 5-D: NCDHW (kNCHW) or NDHWC.
inline void ExtractNCDWH(const framework::DDim& dims,
                         const DataLayout& data_layout, int* N, int* C, int* D,
                         int* H, int* W) {
  const bool channels_first = (data_layout == DataLayout::kNCHW);
  *N = dims[0];
  if (dims.size() == 4) {
    if (channels_first) {
      *C = dims[1];
      *D = 1;
      *H = dims[2];
      *W = dims[3];
    } else {
      *C = dims[3];
      *D = 1;
      *H = dims[1];
      *W = dims[2];
    }
    return;
  }
  if (channels_first) {
    *C = dims[1];
    *D = dims[2];
    *H = dims[3];
    *W = dims[4];
  } else {
    *C = dims[4];
    *D = dims[1];
    *H = dims[2];
    *W = dims[3];
  }
}

80 81 82 83
template <typename T>
static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
                                       const float ratio_h, const float ratio_w,
                                       const int n, const int c,
84
                                       const int out_h, const int out_w,
85 86
                                       const bool align_corners,
                                       const DataLayout& data_layout) {
87 88 89
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);
  for (int k = 0; k < out_h; k++) {  // loop for images
90 91
    int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
                               : static_cast<int>(ratio_h * k);
92 93

    for (int l = 0; l < out_w; l++) {
94 95
      int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
                                 : static_cast<int>(ratio_w * l);
96 97 98

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
99 100 101 102 103
          if (data_layout == DataLayout::kNCHW) {
            output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
          } else {
            output_t(i, k, l, j) = input_t(i, in_k, in_l, j);
          }
104 105 106 107 108 109 110 111 112 113
        }
      }
    }
  }
}

template <typename T>
static void BilinearInterpolation(const Tensor& input, Tensor* output,
                                  const float ratio_h, const float ratio_w,
                                  const int in_h, const int in_w, const int n,
114 115
                                  const int c, const int out_h, const int out_w,
                                  const bool align_corners,
116 117
                                  const bool align_mode,
                                  const DataLayout data_layout) {
118 119
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);
T
tink2123 已提交
120
  bool align_flag = (align_mode == 0 && !align_corners);
121 122 123 124 125 126 127 128 129 130 131

  std::vector<int> vy_n, vy_s;
  std::vector<float> vd_n, vd_s;
  vy_n.reserve(out_h);
  vy_s.reserve(out_h);
  vd_n.reserve(out_h);
  vd_s.reserve(out_h);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int k = 0; k < out_h; k++) {
T
tink2123 已提交
132 133
    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * k);
T
tink2123 已提交
134
    y_n = (y_n > 0) ? y_n : 0;
135
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
136 137 138
    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
139
    float d_s = 1.f - d_n;
140 141 142 143 144 145 146
    {
      vy_n[k] = y_n;
      vy_s[k] = y_s;
      vd_n[k] = d_n;
      vd_s[k] = d_s;
    }
  }
147

148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
  std::vector<int> vx_w, vx_e;
  std::vector<float> vd_w, vd_e;
  vx_w.reserve(out_w);
  vx_e.reserve(out_w);
  vd_w.reserve(out_w);
  vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int l = 0; l < out_w; l++) {
    int x_w = (align_mode == 0 && !align_corners)
                  ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                  : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;
    int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
163 164 165
    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
166 167 168 169 170 171 172 173
    float d_e = 1.f - d_w;
    {
      vx_w[l] = x_w;
      vx_e[l] = x_e;
      vd_w[l] = d_w;
      vd_e[l] = d_e;
    }
  }
174

175 176 177 178 179 180 181
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(4)
#endif
  for (int i = 0; i < n; i++) {          // loop for batches
    for (int j = 0; j < c; j++) {        // loop for channels
      for (int k = 0; k < out_h; k++) {  // loop for images
        for (int l = 0; l < out_w; l++) {
182
          // bilinear interpolation
183 184 185
          T out_t;
          if (data_layout == DataLayout::kNCHW) {
            out_t = input_t(i, j, vy_n[k], vx_w[l]) * vd_s[k] * vd_e[l] +
186 187 188
                    input_t(i, j, vy_s[k], vx_w[l]) * vd_n[k] * vd_e[l] +
                    input_t(i, j, vy_n[k], vx_e[l]) * vd_s[k] * vd_w[l] +
                    input_t(i, j, vy_s[k], vx_e[l]) * vd_n[k] * vd_w[l];
189 190 191 192 193 194 195 196 197
            output_t(i, j, k, l) = out_t;

          } else {
            out_t = input_t(i, vy_n[k], vx_w[l], j) * vd_s[k] * vd_e[l] +
                    input_t(i, vy_s[k], vx_w[l], j) * vd_n[k] * vd_e[l] +
                    input_t(i, vy_n[k], vx_e[l], j) * vd_s[k] * vd_w[l] +
                    input_t(i, vy_s[k], vx_e[l], j) * vd_n[k] * vd_w[l];
            output_t(i, k, l, j) = out_t;
          }
198 199 200 201 202 203
        }
      }
    }
  }
}

K
Kaipeng Deng 已提交
204 205 206 207 208
template <typename T>
static void TrilinearInterpolation(
    const Tensor& input, Tensor* output, const float ratio_d,
    const float ratio_h, const float ratio_w, const int in_d, const int in_h,
    const int in_w, const int n, const int c, const int out_d, const int out_h,
209 210
    const int out_w, const bool align_corners, const bool align_mode,
    const DataLayout& data_layout) {
K
Kaipeng Deng 已提交
211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302
  auto input_t = EigenTensor<T, 5>::From(input);
  auto output_t = EigenTensor<T, 5>::From(*output);
  bool align_flag = (align_mode == 0 && !align_corners);

  std::vector<int> vt_f, vt_b;
  std::vector<float> vd_f, vd_b;
  vt_f.reserve(out_d);
  vt_b.reserve(out_d);
  vd_f.reserve(out_d);
  vd_b.reserve(out_d);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int j = 0; j < out_d; j++) {
    int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
                         : static_cast<int>(ratio_d * j);
    t_f = (t_f > 0) ? t_f : 0;
    int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
    float idx_src_t = ratio_d * (j + 0.5) - 0.5;
    idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
    float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
    float d_b = 1.f - d_f;
    {
      vt_f[j] = t_f;
      vt_b[j] = t_b;
      vd_f[j] = d_f;
      vd_b[j] = d_b;
    }
  }

  std::vector<int> vy_n, vy_s;
  std::vector<float> vd_n, vd_s;
  vy_n.reserve(out_h);
  vy_s.reserve(out_h);
  vd_n.reserve(out_h);
  vd_s.reserve(out_h);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int k = 0; k < out_h; k++) {
    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * k);
    y_n = (y_n > 0) ? y_n : 0;
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
    float d_s = 1.f - d_n;
    {
      vy_n[k] = y_n;
      vy_s[k] = y_s;
      vd_n[k] = d_n;
      vd_s[k] = d_s;
    }
  }

  std::vector<int> vx_w, vx_e;
  std::vector<float> vd_w, vd_e;
  vx_w.reserve(out_w);
  vx_e.reserve(out_w);
  vd_w.reserve(out_w);
  vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int l = 0; l < out_w; l++) {
    int x_w = (align_mode == 0 && !align_corners)
                  ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                  : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;
    int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
    float d_e = 1.f - d_w;
    {
      vx_w[l] = x_w;
      vx_e[l] = x_e;
      vd_w[l] = d_w;
      vd_e[l] = d_e;
    }
  }

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(5)
#endif
  for (int b = 0; b < n; b++) {          // loop for batches
    for (int i = 0; i < c; i++) {        // loop for channels
      for (int j = 0; j < out_d; j++) {  // loop for D, H, W
        for (int k = 0; k < out_h; k++) {
          for (int l = 0; l < out_w; l++) {
            // trilinear interpolation
303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
            if (data_layout == DataLayout::kNCHW) {
              T out_t = input_t(b, i, vt_f[j], vy_n[k], vx_w[l]) * vd_b[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, i, vt_f[j], vy_n[k], vx_e[l]) * vd_b[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, i, vt_f[j], vy_s[k], vx_w[l]) * vd_b[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, i, vt_f[j], vy_s[k], vx_e[l]) * vd_b[j] *
                            vd_n[k] * vd_w[l] +
                        input_t(b, i, vt_b[j], vy_n[k], vx_w[l]) * vd_f[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, i, vt_b[j], vy_n[k], vx_e[l]) * vd_f[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, i, vt_b[j], vy_s[k], vx_w[l]) * vd_f[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, i, vt_b[j], vy_s[k], vx_e[l]) * vd_f[j] *
                            vd_n[k] * vd_w[l];
              output_t(b, i, j, k, l) = out_t;
            } else {
              T out_t = input_t(b, vt_f[j], vy_n[k], vx_w[l], i) * vd_b[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, vt_f[j], vy_n[k], vx_e[l], i) * vd_b[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, vt_f[j], vy_s[k], vx_w[l], i) * vd_b[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, vt_f[j], vy_s[k], vx_e[l], i) * vd_b[j] *
                            vd_n[k] * vd_w[l] +
                        input_t(b, vt_b[j], vy_n[k], vx_w[l], i) * vd_f[j] *
                            vd_s[k] * vd_e[l] +
                        input_t(b, vt_b[j], vy_n[k], vx_e[l], i) * vd_f[j] *
                            vd_s[k] * vd_w[l] +
                        input_t(b, vt_b[j], vy_s[k], vx_w[l], i) * vd_f[j] *
                            vd_n[k] * vd_e[l] +
                        input_t(b, vt_b[j], vy_s[k], vx_e[l], i) * vd_f[j] *
                            vd_n[k] * vd_w[l];
              output_t(b, j, k, l, i) = out_t;
            }
K
Kaipeng Deng 已提交
340 341 342 343 344 345 346
          }
        }
      }
    }
  }
}

X
xiaoting 已提交
347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391
// Inner branch of the cubic convolution kernel:
//   W(x) = (A + 2) x^3 - (A + 3) x^2 + 1
// evaluated in Horner-like form. Callers in this file pass distances in
// [0, 1].
template <typename T>
HOSTDEVICE inline T cubic_convolution1(T x, T A) {
  const T inner = (A + 2) * x - (A + 3);
  return inner * x * x + 1;
}

// Outer branch of the cubic convolution kernel:
//   W(x) = A x^3 - 5A x^2 + 8A x - 4A
// evaluated in Horner-like form. Callers in this file pass distances in
// [1, 2].
template <typename T>
HOSTDEVICE inline T cubic_convolution2(T x, T A) {
  const T quadratic_part = (A * x - 5 * A) * x + 8 * A;
  return quadratic_part * x - 4 * A;
}

// Writes into coeffs[0..3] the cubic-convolution weights (A = -0.75) of four
// consecutive taps, given the fractional offset `t` of the sample position
// past the second tap. The taps sit at distances 1 + t, t, 1 - t and 2 - t
// from the sample.
template <typename T>
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) {
  const T A = -0.75;

  // Left-hand taps: distances t + 1 and t.
  coeffs[0] = cubic_convolution2<T>(t + 1.0, A);
  coeffs[1] = cubic_convolution1<T>(t, A);

  // Right-hand taps mirror the left: distances 1 - t and 2 - t.
  const T mirrored = 1.0 - t;
  coeffs[2] = cubic_convolution1<T>(mirrored, A);
  coeffs[3] = cubic_convolution2<T>(mirrored + 1.0, A);
}

// 1-D cubic interpolation of four consecutive samples x0..x3 at fractional
// offset `t` past x1, using the weights from get_cubic_upsample_coefficients.
template <typename T>
static inline T cubic_interp(T x0, T x1, T x2, T x3, T t) {
  T w[4];
  get_cubic_upsample_coefficients<T>(w, t);
  return x0 * w[0] + x1 * w[1] + x2 * w[2] + x3 * w[3];
}

// Bicubic resize of a 4-D tensor (NCHW when data_layout is kNCHW, otherwise
// NHWC).
//
// Source position:
// - With align_corners: src = ratio * dst.
// - Otherwise: src = ratio * (dst + 0.5) - 0.5.
//
// A 4x4 neighbourhood starting at floor(src) - 1 is sampled. Indices are
// clamped into the input bounds, so border pixels are replicated. Each of
// the four rows is first interpolated along x with cubic_interp; the four
// row results are then interpolated along y.
//
// The clamped neighbour indices depend only on the output row / column.
// They are computed once there instead of once per batch, channel and tap.
template <typename T>
static void BicubicInterpolation(const Tensor& input, Tensor* output,
                                 const float ratio_h, const float ratio_w,
                                 const int in_h, const int in_w, const int n,
                                 const int c, const int out_h, const int out_w,
                                 const bool align_corners,
                                 const DataLayout data_layout) {
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);

  for (int k = 0; k < out_h; k++) {  // loop for images
    T y_n = align_corners ? static_cast<T>(ratio_h * k)
                          : static_cast<T>(ratio_h * (k + 0.5) - 0.5);
    int input_y = floorf(y_n);
    const T y_t = y_n - input_y;

    // Source rows of the 4x4 neighbourhood, clamped into [0, in_h - 1].
    int access_y[4];
    for (int ii = 0; ii < 4; ii++) {
      access_y[ii] = std::max(std::min(input_y - 1 + ii, in_h - 1), 0);
    }

    for (int l = 0; l < out_w; l++) {
      T x_n = align_corners ? static_cast<T>(ratio_w * l)
                            : static_cast<T>(ratio_w * (l + 0.5) - 0.5);
      int input_x = floorf(x_n);
      const T x_t = x_n - input_x;

      // Source columns of the 4x4 neighbourhood, clamped into [0, in_w - 1].
      int access_x[4];
      for (int jj = 0; jj < 4; jj++) {
        access_x[jj] = std::max(std::min(input_x - 1 + jj, in_w - 1), 0);
      }

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          T row_interp[4];
          // interp 4 times in x direction
          for (int ii = 0; ii < 4; ii++) {
            const int y = access_y[ii];
            if (data_layout == DataLayout::kNCHW) {
              row_interp[ii] = cubic_interp<T>(
                  input_t(i, j, y, access_x[0]), input_t(i, j, y, access_x[1]),
                  input_t(i, j, y, access_x[2]), input_t(i, j, y, access_x[3]),
                  x_t);
            } else {
              row_interp[ii] = cubic_interp<T>(
                  input_t(i, y, access_x[0], j), input_t(i, y, access_x[1], j),
                  input_t(i, y, access_x[2], j), input_t(i, y, access_x[3], j),
                  x_t);
            }
          }

          // interp y direction
          const T value = cubic_interp<T>(row_interp[0], row_interp[1],
                                          row_interp[2], row_interp[3], y_t);
          if (data_layout == DataLayout::kNCHW) {
            output_t(i, j, k, l) = value;
          } else {
            output_t(i, k, l, j) = value;
          }
        }
      }
    }
  }
}

447
template <typename T>
448 449 450
static void NearestNeighborInterpolateGrad(
    const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
    const float ratio_w, const int n, const int c, const int out_h,
451
    const int out_w, const bool align_corners, const DataLayout data_layout) {
452 453
  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
454

455
  for (int k = 0; k < out_h; k++) {  // loop for images
456 457
    int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
                               : static_cast<int>(ratio_h * k);
458 459

    for (int l = 0; l < out_w; l++) {
460 461
      int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
                                 : static_cast<int>(ratio_w * l);
462 463 464

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
465 466 467 468 469
          if (data_layout == DataLayout::kNCHW) {
            input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
          } else {
            input_grad_t(i, in_k, in_l, j) += output_grad_t(i, k, l, j);
          }
470 471 472 473 474 475 476
        }
      }
    }
  }
}

template <typename T>
477 478 479 480 481
static void BilinearInterpolationGrad(
    const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
    const float ratio_w, const int in_h, const int in_w, const int n,
    const int c, const int out_h, const int out_w, const bool align_corners,
    const int align_mode, const DataLayout data_layout) {
482 483
  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
T
tink2123 已提交
484
  bool align_flag = (align_mode == 0 && !align_corners);
485
  for (int k = 0; k < out_h; k++) {  // loop for images
T
tink2123 已提交
486 487
    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * k);
T
tink2123 已提交
488
    y_n = (y_n > 0) ? y_n : 0;
489
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
490 491 492
    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
493 494 495
    float d_s = 1.f - d_n;

    for (int l = 0; l < out_w; l++) {
T
tink2123 已提交
496 497
      int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                           : static_cast<int>(ratio_w * l);
T
tink2123 已提交
498
      x_w = (x_w > 0) ? x_w : 0;
499
      int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
500 501 502
      float idx_src_x = ratio_w * (l + 0.5) - 0.5;
      idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
      float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
503 504 505 506 507
      float d_e = 1.f - d_w;

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          // bilinear interpolation grad
508 509 510 511 512 513 514 515 516 517 518 519 520
          if (data_layout == DataLayout::kNCHW) {
            const T grad = output_grad_t(i, j, k, l);
            input_grad_t(i, j, y_n, x_w) += static_cast<T>(grad * d_s * d_e);
            input_grad_t(i, j, y_s, x_w) += static_cast<T>(grad * d_n * d_e);
            input_grad_t(i, j, y_n, x_e) += static_cast<T>(grad * d_s * d_w);
            input_grad_t(i, j, y_s, x_e) += static_cast<T>(grad * d_n * d_w);
          } else {
            const T grad = output_grad_t(i, k, l, j);
            input_grad_t(i, y_n, x_w, j) += static_cast<T>(grad * d_s * d_e);
            input_grad_t(i, y_s, x_w, j) += static_cast<T>(grad * d_n * d_e);
            input_grad_t(i, y_n, x_e, j) += static_cast<T>(grad * d_s * d_w);
            input_grad_t(i, y_s, x_e, j) += static_cast<T>(grad * d_n * d_w);
          }
521 522 523 524 525
        }
      }
    }
  }
}
K
Kaipeng Deng 已提交
526

527
template <typename T>
K
Kaipeng Deng 已提交
528 529 530 531
static void TrilinearInterpolationGrad(
    const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
    const float ratio_h, const float ratio_w, const int in_d, const int in_h,
    const int in_w, const int n, const int c, const int out_d, const int out_h,
532 533
    const int out_w, const bool align_corners, const int align_mode,
    const DataLayout data_layout) {
K
Kaipeng Deng 已提交
534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569
  auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
  bool align_flag = (align_mode == 0 && !align_corners);
  for (int j = 0; j < out_d; j++) {  // loop for D
    int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
                         : static_cast<int>(ratio_d * j);
    t_f = (t_f > 0) ? t_f : 0;
    int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
    float idx_src_t = ratio_d * (j + 0.5) - 0.5;
    idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
    float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
    float d_b = 1.f - d_f;

    for (int k = 0; k < out_h; k++) {  // loop for H
      int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                           : static_cast<int>(ratio_h * k);
      y_n = (y_n > 0) ? y_n : 0;
      int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
      float idx_src_y = ratio_h * (k + 0.5) - 0.5;
      idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
      float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
      float d_s = 1.f - d_n;

      for (int l = 0; l < out_w; l++) {  // loop for W
        int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                             : static_cast<int>(ratio_w * l);
        x_w = (x_w > 0) ? x_w : 0;
        int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
        float idx_src_x = ratio_w * (l + 0.5) - 0.5;
        idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
        float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
        float d_e = 1.f - d_w;

        for (int b = 0; b < n; b++) {    // loop for batches
          for (int i = 0; i < c; i++) {  // loop for channels
            // trilinear interpolation grad
570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606
            if (data_layout == DataLayout::kNCHW) {
              const T grad = output_grad_t(b, i, j, k, l);
              input_grad_t(b, i, t_f, y_n, x_w) +=
                  static_cast<T>(grad * d_b * d_s * d_e);
              input_grad_t(b, i, t_f, y_n, x_e) +=
                  static_cast<T>(grad * d_b * d_s * d_w);
              input_grad_t(b, i, t_f, y_s, x_w) +=
                  static_cast<T>(grad * d_b * d_n * d_e);
              input_grad_t(b, i, t_f, y_s, x_e) +=
                  static_cast<T>(grad * d_b * d_n * d_w);
              input_grad_t(b, i, t_b, y_n, x_w) +=
                  static_cast<T>(grad * d_f * d_s * d_e);
              input_grad_t(b, i, t_b, y_n, x_e) +=
                  static_cast<T>(grad * d_f * d_s * d_w);
              input_grad_t(b, i, t_b, y_s, x_w) +=
                  static_cast<T>(grad * d_f * d_n * d_e);
              input_grad_t(b, i, t_b, y_s, x_e) +=
                  static_cast<T>(grad * d_f * d_n * d_w);
            } else {
              const T grad = output_grad_t(b, j, k, l, i);
              input_grad_t(b, t_f, y_n, x_w, i) +=
                  static_cast<T>(grad * d_b * d_s * d_e);
              input_grad_t(b, t_f, y_n, x_e, i) +=
                  static_cast<T>(grad * d_b * d_s * d_w);
              input_grad_t(b, t_f, y_s, x_w, i) +=
                  static_cast<T>(grad * d_b * d_n * d_e);
              input_grad_t(b, t_f, y_s, x_e, i) +=
                  static_cast<T>(grad * d_b * d_n * d_w);
              input_grad_t(b, t_b, y_n, x_w, i) +=
                  static_cast<T>(grad * d_f * d_s * d_e);
              input_grad_t(b, t_b, y_n, x_e, i) +=
                  static_cast<T>(grad * d_f * d_s * d_w);
              input_grad_t(b, t_b, y_s, x_w, i) +=
                  static_cast<T>(grad * d_f * d_n * d_e);
              input_grad_t(b, t_b, y_s, x_e, i) +=
                  static_cast<T>(grad * d_f * d_n * d_w);
            }
K
Kaipeng Deng 已提交
607 608 609 610 611 612
          }
        }
      }
    }
  }
}
613

X
xiaoting 已提交
614 615 616 617 618 619 620 621 622 623 624 625 626 627
// Backward pass of BicubicInterpolation.
//
// Each output-gradient element is scattered (+=) onto the clamped 4x4 source
// neighbourhood it was sampled from. The weight of each tap is
// y_coeffs[jj] * x_coeffs[ii], with the same coefficients as the forward
// pass. Because it accumulates, input_grad presumably has to be zeroed by
// the caller beforehand — TODO confirm.
//
// The following are loop-invariant and are computed at the outermost loop
// where they are constant, rather than per batch/channel/tap:
// - the y coefficients and clamped source rows (per output row);
// - the x coefficients and clamped source columns (per output column);
// - the gradient read (per batch/channel).
// The accumulation order and arithmetic are unchanged.
template <typename T>
static void BicubicInterpolationGrad(const Tensor& output_grad,
                                     Tensor* input_grad, const float ratio_h,
                                     const float ratio_w, const int in_h,
                                     const int in_w, const int n, const int c,
                                     const int out_h, const int out_w,
                                     const bool align_corners,
                                     const DataLayout data_layout) {
  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);

  for (int k = 0; k < out_h; k++) {  // loop for images
    T y_n = align_corners ? static_cast<T>(ratio_h * k)
                          : static_cast<T>(ratio_h * (k + 0.5) - 0.5);
    int input_y = floorf(y_n);
    T y_t = y_n - input_y;

    T y_coeffs[4];
    get_cubic_upsample_coefficients<T>(y_coeffs, y_t);

    // Source rows of the 4x4 neighbourhood, clamped into [0, in_h - 1].
    int access_y[4];
    for (int jj = 0; jj < 4; jj++) {
      access_y[jj] = std::max(std::min(input_y - 1 + jj, in_h - 1), 0);
    }

    for (int l = 0; l < out_w; l++) {
      T x_n = align_corners ? static_cast<T>(ratio_w * l)
                            : static_cast<T>(ratio_w * (l + 0.5) - 0.5);
      int input_x = floorf(x_n);
      T x_t = x_n - input_x;

      T x_coeffs[4];
      get_cubic_upsample_coefficients<T>(x_coeffs, x_t);

      // Source columns of the 4x4 neighbourhood, clamped into [0, in_w - 1].
      int access_x[4];
      for (int ii = 0; ii < 4; ii++) {
        access_x[ii] = std::max(std::min(input_x - 1 + ii, in_w - 1), 0);
      }

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          const T grad = (data_layout == DataLayout::kNCHW)
                             ? output_grad_t(i, j, k, l)
                             : output_grad_t(i, k, l, j);
          // bicubic interpolation grad; ii walks x, jj walks y (the original
          // order is kept because clamped taps may coincide at the border).
          for (int ii = 0; ii < 4; ii++) {
            for (int jj = 0; jj < 4; jj++) {
              if (data_layout == DataLayout::kNCHW) {
                input_grad_t(i, j, access_y[jj], access_x[ii]) +=
                    grad * y_coeffs[jj] * x_coeffs[ii];
              } else {
                input_grad_t(i, access_y[jj], access_x[ii], j) +=
                    grad * y_coeffs[jj] * x_coeffs[ii];
              }
            }
          }
        }
      }
    }
  }
}

K
Kaipeng Deng 已提交
669 670 671
// Forward pass of 2-D interpolation (bilinear / nearest / bicubic) on CPU.
// Resolves the target spatial size from the op's inputs/attributes,
// allocates `output`, and either copies `input` through unchanged (same
// size) or dispatches to the interpolation routine named by the
// "interp_method" attribute.
template <typename T>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
                                const Tensor& input, Tensor* output) {
  // Decode the memory layout and read N, C and the spatial extents of X.
  // in_d is not used by the 2-D path.
  const DataLayout data_layout =
      framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  const std::string method = ctx.Attr<std::string>("interp_method");
  const bool corners_aligned = ctx.Attr<bool>("align_corners");
  const int mode = ctx.Attr<int>("align_mode");

  // Target size. A non-empty SizeTensor list wins outright. Otherwise a
  // positive scale (Scale input if given, else the "scale" attr) rescales
  // the input size, and an OutSize input overrides that. The out_h/out_w
  // attrs are the fallback when none of these apply.
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  const auto size_tensors = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (!size_tensors.empty()) {
    const std::vector<int> hw = get_new_shape(size_tensors);
    out_h = hw[0];
    out_w = hw[1];
  } else {
    float scale;
    if (const Tensor* scale_tensor = ctx.Input<Tensor>("Scale")) {
      const auto scale_values = get_new_data_from_tensor<float>(scale_tensor);
      scale = scale_values[0];
    } else {
      scale = ctx.Attr<float>("scale");
    }
    if (scale > 0) {
      out_h = static_cast<int>(in_h * scale);
      out_w = static_cast<int>(in_w * scale);
    }
    if (const Tensor* out_size = ctx.Input<Tensor>("OutSize")) {
      const auto hw = get_new_data_from_tensor<int>(out_size);
      out_h = hw[0];
      out_w = hw[1];
    }
  }
  PADDLE_ENFORCE_GT(
      out_h, 0,
      "out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
  PADDLE_ENFORCE_GT(
      out_w, 0,
      "out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");

  const framework::DDim dim_out =
      data_layout == DataLayout::kNCHW
          ? framework::make_ddim({n, c, out_h, out_w})
          : framework::make_ddim({n, out_h, out_w, c});
  output->mutable_data<T>(dim_out, ctx.GetPlace());

  // Same spatial size: nothing to interpolate, copy X straight through.
  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Step (in input pixels) between consecutive output samples along one
  // axis; left at 0 when that output axis has a single sample.
  const auto axis_ratio = [corners_aligned](int in_size, int out_size) {
    if (out_size <= 1) {
      return 0.f;
    }
    return corners_aligned ? static_cast<float>(in_size - 1) / (out_size - 1)
                           : static_cast<float>(in_size) / out_size;
  };
  float ratio_h = axis_ratio(in_h, out_h);
  float ratio_w = axis_ratio(in_w, out_w);

  if (method == "bilinear") {
    BilinearInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
                             out_h, out_w, corners_aligned, mode, data_layout);
  } else if (method == "nearest") {
    NearestNeighborInterpolate<T>(input, output, ratio_h, ratio_w, n, c, out_h,
                                  out_w, corners_aligned, data_layout);
  } else if (method == "bicubic") {
    BicubicInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
                            out_h, out_w, corners_aligned, data_layout);
  }
}
752

K
Kaipeng Deng 已提交
753 754 755
// Forward pass of 3-D interpolation on CPU. Resolves the target
// depth/height/width from the op's inputs/attributes, allocates `output`,
// and either copies `input` through unchanged (same size) or runs the
// "trilinear" routine. Other interp_method values perform no computation
// here.
template <typename T>
static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
                                const Tensor& input, Tensor* output) {
  // Decode the memory layout and read N, C and the spatial extents of X.
  const DataLayout data_layout =
      framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  const std::string method = ctx.Attr<std::string>("interp_method");
  const bool corners_aligned = ctx.Attr<bool>("align_corners");
  const int mode = ctx.Attr<int>("align_mode");

  // Target size. A non-empty SizeTensor list wins outright. Otherwise a
  // positive scale (Scale input if given, else the "scale" attr) rescales
  // the input size, and an OutSize input overrides that. The
  // out_d/out_h/out_w attrs are the fallback when none of these apply.
  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  const auto size_tensors = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (!size_tensors.empty()) {
    const std::vector<int> dhw = get_new_shape(size_tensors);
    out_d = dhw[0];
    out_h = dhw[1];
    out_w = dhw[2];
  } else {
    float scale;
    if (const Tensor* scale_tensor = ctx.Input<Tensor>("Scale")) {
      const auto scale_values = get_new_data_from_tensor<float>(scale_tensor);
      scale = scale_values[0];
    } else {
      scale = ctx.Attr<float>("scale");
    }
    if (scale > 0) {
      out_d = static_cast<int>(in_d * scale);
      out_h = static_cast<int>(in_h * scale);
      out_w = static_cast<int>(in_w * scale);
    }
    if (const Tensor* out_size = ctx.Input<Tensor>("OutSize")) {
      const auto dhw = get_new_data_from_tensor<int>(out_size);
      out_d = dhw[0];
      out_h = dhw[1];
      out_w = dhw[2];
    }
  }
  PADDLE_ENFORCE_GT(
      out_d, 0,
      "out_d in Attr(out_shape) of Op(interpolate) should be greater than 0.");
  PADDLE_ENFORCE_GT(
      out_h, 0,
      "out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
  PADDLE_ENFORCE_GT(
      out_w, 0,
      "out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");

  const framework::DDim dim_out =
      data_layout == DataLayout::kNCHW
          ? framework::make_ddim({n, c, out_d, out_h, out_w})
          : framework::make_ddim({n, out_d, out_h, out_w, c});
  output->mutable_data<T>(dim_out, ctx.GetPlace());

  // Same spatial size: nothing to interpolate, copy X straight through.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Step (in input voxels) between consecutive output samples along one
  // axis; left at 0 when that output axis has a single sample.
  const auto axis_ratio = [corners_aligned](int in_size, int out_size) {
    if (out_size <= 1) {
      return 0.f;
    }
    return corners_aligned ? static_cast<float>(in_size - 1) / (out_size - 1)
                           : static_cast<float>(in_size) / out_size;
  };
  float ratio_d = axis_ratio(in_d, out_d);
  float ratio_h = axis_ratio(in_h, out_h);
  float ratio_w = axis_ratio(in_w, out_w);

  if (method == "trilinear") {
    TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d,
                              in_h, in_w, n, c, out_d, out_h, out_w,
                              corners_aligned, mode, data_layout);
  }
}
844 845

template <typename T>
K
Kaipeng Deng 已提交
846 847 848
static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
                                Tensor* input_grad, const Tensor& output_grad) {
  auto* input = ctx.Input<Tensor>("X");
849 850 851 852
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
  const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
K
Kaipeng Deng 已提交
853 854 855 856 857 858 859

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
860 861 862 863 864 865 866 867
  float scale;
  auto scale_tensor = ctx.Input<Tensor>("Scale");
  if (scale_tensor != nullptr) {
    auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
    scale = scale_data[0];
  } else {
    scale = ctx.Attr<float>("scale");
  }
K
Kaipeng Deng 已提交
868 869 870 871 872 873
  if (scale > 0) {
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }
  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
874
    auto out_size_data = get_new_data_from_tensor<int>(out_size);
K
Kaipeng Deng 已提交
875 876 877
    out_h = out_size_data[0];
    out_w = out_size_data[1];
  }
878 879 880 881 882 883 884
  auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    out_h = new_size[0];
    out_w = new_size[1];
  }
D
dengkaipeng 已提交
885

886 887 888 889 890 891 892 893
  framework::DDim dim_grad;
  if (data_layout == DataLayout::kNCHW) {
    dim_grad = {n, c, in_h, in_w};
  } else {
    dim_grad = {n, in_h, in_w, c};
  }
  input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());

K
Kaipeng Deng 已提交
894 895 896
  auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
  math::SetConstant<platform::CPUDeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));
D
dengkaipeng 已提交
897

K
Kaipeng Deng 已提交
898 899 900 901
  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }
D
dengkaipeng 已提交
902

K
Kaipeng Deng 已提交
903 904 905 906 907 908 909 910 911 912 913 914 915 916
  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  if ("bilinear" == interp_method) {
    BilinearInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
                                 in_h, in_w, n, c, out_h, out_w, align_corners,
917
                                 align_mode, data_layout);
K
Kaipeng Deng 已提交
918 919
  } else if ("nearest" == interp_method) {
    NearestNeighborInterpolateGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
920 921
                                      n, c, out_h, out_w, align_corners,
                                      data_layout);
X
xiaoting 已提交
922 923 924 925
  } else if ("bicubic" == interp_method) {
    BicubicInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w, in_h,
                                in_w, n, c, out_h, out_w, align_corners,
                                data_layout);
K
Kaipeng Deng 已提交
926 927
  }
}
D
dengkaipeng 已提交
928

K
Kaipeng Deng 已提交
929 930 931 932
// Backward pass of 3-D (trilinear) interpolation on CPU. Recomputes the
// forward output size, allocates `input_grad` with the shape of X, and
// either passes `output_grad` straight through (same size) or
// zero-initialises `input_grad` and runs TrilinearInterpolationGrad.
//
// `output_grad` is taken by const reference (it was previously taken by
// value, copying the Tensor object on every call), matching
// Interpolate2DCPUBwd.
template <typename T>
static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
                                Tensor* input_grad, const Tensor& output_grad) {
  auto* input = ctx.Input<Tensor>("X");
  const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
  const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
  int n, c, in_d, in_h, in_w;
  ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  // Resolve the forward output size with the same structure as the forward
  // kernel: a non-empty SizeTensor list wins outright; otherwise a positive
  // scale rescales the input size and OutSize overrides it. Scale/OutSize
  // are only read when SizeTensor is absent, since they could not win.
  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
  if (list_new_size_tensor.size() > 0) {
    // have size tensor
    auto new_size = get_new_shape(list_new_size_tensor);
    out_d = new_size[0];
    out_h = new_size[1];
    out_w = new_size[2];
  } else {
    float scale;
    auto scale_tensor = ctx.Input<Tensor>("Scale");
    if (scale_tensor != nullptr) {
      auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
      scale = scale_data[0];
    } else {
      scale = ctx.Attr<float>("scale");
    }
    if (scale > 0) {
      out_d = static_cast<int>(in_d * scale);
      out_h = static_cast<int>(in_h * scale);
      out_w = static_cast<int>(in_w * scale);
    }
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      auto out_size_data = get_new_data_from_tensor<int>(out_size);
      out_d = out_size_data[0];
      out_h = out_size_data[1];
      out_w = out_size_data[2];
    }
  }

  framework::DDim dim_grad;
  if (data_layout == DataLayout::kNCHW) {
    dim_grad = {n, c, in_d, in_h, in_w};
  } else {
    dim_grad = {n, in_d, in_h, in_w, c};
  }
  input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());

  // Identity resize: the gradient passes straight through. TensorCopy
  // overwrites every element, so no zero-fill is needed on this path.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // The grad routine scatters into input_grad, so it must start from zero.
  auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
  math::SetConstant<platform::CPUDeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  float ratio_d = 0.f;
  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_d > 1) {
    ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
                              : static_cast<float>(in_d) / out_d;
  }
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  if ("trilinear" == interp_method) {
    TrilinearInterpolationGrad<T>(
        output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n,
        c, out_d, out_h, out_w, align_corners, align_mode, data_layout);
  }
}

// CPU forward kernel for the interpolate op. Dispatches on the rank of X:
// rank 4 goes to the 2-D path, rank 5 to the 3-D path. Any other rank
// leaves Out untouched.
template <typename T>
class InterpolateKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* x = ctx.Input<Tensor>("X");
    Tensor* out = ctx.Output<Tensor>("Out");

    switch (x->dims().size()) {
      case 4:  // 2D interpolation
        Interpolate2DCPUFwd<T>(ctx, *x, out);
        break;
      case 5:  // 3D interpolation
        Interpolate3DCPUFwd<T>(ctx, *x, out);
        break;
      default:
        break;
    }
  }
};

// CPU backward kernel for the interpolate op. Dispatches on the rank of
// Out@GRAD: rank 4 goes to the 2-D grad path, rank 5 to the 3-D grad path.
// Any other rank leaves X@GRAD untouched.
template <typename T>
class InterpolateGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    const int rank = dout->dims().size();
    switch (rank) {
      case 4:  // 2D interpolation grad
        Interpolate2DCPUBwd<T>(ctx, dx, *dout);
        break;
      case 5:  // 3D interpolation grad
        Interpolate3DCPUBwd<T>(ctx, dx, *dout);
        break;
      default:
        break;
    }
  }
};

}  // namespace operators
}  // namespace paddle