// interpolate_op.cu
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include <string>
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

// Nearest-neighbor interpolation, forward pass.
//
// Layout: `in` is addressed as an [input_h, input_w] matrix in which every row
// holds `num_channels` contiguous images of in_img_h x in_img_w elements
// (row-major). `out` is the analogous [output_h, output_w] matrix of
// out_img_h x out_img_w images. The host wrapper passes the batch size as
// input_h/output_h and C*H*W as input_w/output_w.
//
// Launch: 1-D grid of 1-D blocks. The grid-stride loop covers all
// output_h * output_w elements for any grid size. Each thread writes one
// output element per iteration; no shared memory is used.
//
// ratio_h/ratio_w scale an output coordinate to an input coordinate. With
// align_corners the scaled coordinate is rounded (+0.5, then truncated);
// otherwise it is truncated. in_img_h, input_h, out_img_h and output_h are
// accepted for signature symmetry with the other kernels.
template <typename T>
__global__ void KeNearestNeighborInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w,
    const bool align_corners) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (; tid < nthreads; tid += stride) {
    // Row of the flattened output and column inside that row.
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = (align_corners)
                         ? static_cast<int>(ratio_h * out_img_idy + 0.5)
                         : static_cast<int>(ratio_h * out_img_idy);

    // output_w is a multiple of out_img_w (as the host wrapper arranges), so
    // tid % out_img_w is the column inside the current image.
    int out_img_idx = tid % out_img_w;
    int in_img_idx = (align_corners)
                         ? static_cast<int>(ratio_w * out_img_idx + 0.5)
                         : static_cast<int>(ratio_w * out_img_idx);

    out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
                  in_img_idy * in_img_w + in_img_idx];
  }
}

// Nearest-neighbor interpolation, backward pass.
//
// Each element of the output gradient `out` is added to the input-gradient
// element of `in` that KeNearestNeighborInterpFw sampled it from. Several
// output elements can share one source, so the accumulation uses
// platform::CudaAtomicAdd, and `in` must be zero-initialised before launch.
//
// Layout and launch are the same as the forward kernel: [rows, num_channels *
// image] matrices, 1-D grid with a grid-stride loop over
// output_h * output_w elements.
template <typename T>
__global__ void KeNearestNeighborInterpBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
    const size_t input_w, const T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w,
    const bool align_corners) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (; tid < nthreads; tid += stride) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    // Same source-coordinate mapping as the forward kernel.
    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = (align_corners)
                         ? static_cast<int>(ratio_h * out_img_idy + 0.5)
                         : static_cast<int>(ratio_h * out_img_idy);

    int out_img_idx = tid % out_img_w;
    int in_img_idx = (align_corners)
                         ? static_cast<int>(ratio_w * out_img_idx + 0.5)
                         : static_cast<int>(ratio_w * out_img_idx);

    T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                    in_img_idy * in_img_w + in_img_idx];
    const T out_pos = out[out_id_h * output_w + out_id_w];
    // Atomic because distinct output elements may target the same in_pos.
    platform::CudaAtomicAdd(in_pos, out_pos);
  }
}

// Bilinear interpolation, forward pass.
//
// Layout and launch match KeNearestNeighborInterpFw: `in`/`out` are
// [rows, num_channels * image] matrices of row-major images. A 1-D grid with a
// grid-stride loop covers output_h * output_w elements, one output element per
// thread iteration.
//
// Source coordinates: when align_mode == 0 and !align_corners, the source
// coordinate is ratio * (dst + 0.5) - 0.5 clamped at 0; otherwise it is
// ratio * dst. The integer part selects the top-left neighbour. h_id/w_id
// offset to the next row/column and are 0 on the last row/column, so reads
// stay inside the image.
//
// NOTE(review): the interpolation weights are declared as T, so for integer T
// (this kernel is registered for int) they truncate to 0 and the result is the
// top-left neighbour.
template <typename T>
__global__ void KeBilinearInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w,
    const bool align_corners, const int align_mode) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  bool align_flag = (align_mode == 0 && !align_corners);
  for (; tid < nthreads; tid += stride) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    // Vertical neighbour row, next-row offset and weights.
    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = align_flag
                         ? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * out_img_idy);
    in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
    T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
    src_h = (src_h > 0) ? src_h : 0;
    T h1lambda =
        align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
    T h2lambda = 1.f - h1lambda;

    // Horizontal neighbour column, next-column offset and weights.
    int out_img_idx = tid % out_img_w;
    int in_img_idx = align_flag
                         ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * out_img_idx);
    in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
    T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
    src_w = (src_w > 0) ? src_w : 0;
    T w1lambda =
        align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
    T w2lambda = 1.f - w1lambda;

    const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                          in_img_idy * in_img_w + in_img_idx];

    // bilinear interpolation
    out[out_id_h * output_w + out_id_w] =
        h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
        h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
                    w1lambda * in_pos[h_id * in_img_w + w_id]);
  }
}

// Bilinear interpolation, backward pass.
//
// Recomputes the same four neighbours and weights as KeBilinearInterpFw and
// scatters weight * out-gradient into the input gradient `in`. Neighbours are
// shared between output elements, so all writes use platform::CudaAtomicAdd,
// and `in` must be zero-initialised before launch.
//
// Layout and launch are the same as the forward kernel: a 1-D grid with a
// grid-stride loop over output_h * output_w elements.
template <typename T>
__global__ void KeBilinearInterpBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
    const size_t input_w, const T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const T ratio_h, const T ratio_w,
    const bool align_corners, const int align_mode) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  bool align_flag = (align_mode == 0 && !align_corners);
  for (; tid < nthreads; tid += stride) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    // Truncate toward zero explicitly, exactly as KeBilinearInterpFw does.
    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = align_flag
                         ? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * out_img_idy);
    in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
    T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
    src_h = (src_h > 0) ? src_h : 0;
    T h1lambda =
        align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
    T h2lambda = 1.f - h1lambda;

    int out_img_idx = tid % out_img_w;
    int in_img_idx = align_flag
                         ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * out_img_idx);
    in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
    T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
    src_w = (src_w > 0) ? src_w : 0;
    T w1lambda =
        align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
    T w2lambda = 1.f - w1lambda;

    T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                    in_img_idy * in_img_w + in_img_idx];
    const T* out_pos = &out[out_id_h * output_w + out_id_w];
    // On the last row/column h_id/w_id are 0, so several of these atomics
    // may hit the same element; their weights still sum to the full gradient.
    platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos[h_id * in_img_w],
                            h1lambda * w2lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id],
                            h1lambda * w1lambda * out_pos[0]);
  }
}

// Trilinear interpolation, forward pass.
//
// Layout: `in` is an [input_h, input_w] matrix whose rows each hold
// `num_channels` contiguous volumes of in_img_d x in_img_h x in_img_w elements
// (depth-major, then row, then column). `out` is analogous with the out_img_*
// sizes. The host wrapper passes the batch size as input_h/output_h and
// C*D*H*W as input_w/output_w.
//
// Launch: 1-D grid; the grid-stride loop covers output_h * output_w elements,
// one output element per thread iteration.
//
// Per axis, the source coordinate mapping, neighbour clamping and edge offsets
// (d_id/h_id/w_id are 0 on the last plane/row/column) follow
// KeBilinearInterpFw. The result blends the 8 surrounding input elements.
template <typename T>
__global__ void KeTrilinearInterpFw(
    const T* in, const size_t in_img_d, const size_t in_img_h,
    const size_t in_img_w, const size_t input_h, const size_t input_w, T* out,
    const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
    const size_t output_h, const size_t output_w, const size_t num_channels,
    const float ratio_d, const float ratio_h, const float ratio_w,
    const bool align_corners, const int align_mode) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  bool align_flag = (align_mode == 0 && !align_corners);
  for (; tid < nthreads; tid += stride) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    // Depth axis.
    int out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
    int in_img_idt = align_flag
                         ? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
                         : static_cast<int>(ratio_d * out_img_idt);
    in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
    int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
    T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
    src_d = (src_d > 0) ? src_d : 0;
    T d1lambda =
        align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
    T d2lambda = 1.f - d1lambda;

    // Height axis.
    int out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
    int in_img_idy = align_flag
                         ? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * out_img_idy);
    in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
    T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
    src_h = (src_h > 0) ? src_h : 0;
    T h1lambda =
        align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
    T h2lambda = 1.f - h1lambda;

    // Width axis.
    int out_img_idx = tid % out_img_w;
    int in_img_idx = align_flag
                         ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * out_img_idx);
    in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
    T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
    src_w = (src_w > 0) ? src_w : 0;
    T w1lambda =
        align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
    T w2lambda = 1.f - w1lambda;

    // in_pos1: front plane of the neighbourhood; in_pos2: the plane behind it
    // (or the same plane when d_id == 0).
    int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
                      (in_img_idt * in_img_h + in_img_idy) * in_img_w +
                      in_img_idx;
    const T* in_pos1 = &in[in_pos1_idx];
    int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
    const T* in_pos2 = &in[in_pos2_idx];

    // trilinear interpolation
    out[out_id_h * output_w + out_id_w] =
        d2lambda *
            (h2lambda * (w2lambda * in_pos1[0] + w1lambda * in_pos1[w_id]) +
             h1lambda * (w2lambda * in_pos1[h_id * in_img_w] +
                         w1lambda * in_pos1[h_id * in_img_w + w_id])) +
        d1lambda *
            (h2lambda * (w2lambda * in_pos2[0] + w1lambda * in_pos2[w_id]) +
             h1lambda * (w2lambda * in_pos2[h_id * in_img_w] +
                         w1lambda * in_pos2[h_id * in_img_w + w_id]));
  }
}

// Trilinear interpolation, backward pass.
//
// Recomputes the 8 neighbours and weights used by KeTrilinearInterpFw and
// scatters weight * out-gradient into the input gradient `in` with
// platform::CudaAtomicAdd, because neighbourhoods overlap between output
// elements. `in` must be zero-initialised before launch.
//
// Layout and launch are the same as the forward kernel: a 1-D grid with a
// grid-stride loop over output_h * output_w elements.
template <typename T>
__global__ void KeTrilinearInterpBw(
    T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, const T* out,
    const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
    const size_t output_h, const size_t output_w, const size_t num_channels,
    const T ratio_d, const T ratio_h, const T ratio_w, const bool align_corners,
    const int align_mode) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  bool align_flag = (align_mode == 0 && !align_corners);
  for (; tid < nthreads; tid += stride) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    // Depth axis.
    int out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
    int in_img_idt = align_flag
                         ? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
                         : static_cast<int>(ratio_d * out_img_idt);
    in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
    int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
    T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
    src_d = (src_d > 0) ? src_d : 0;
    T d1lambda =
        align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
    T d2lambda = 1.f - d1lambda;

    // Height axis.
    int out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
    int in_img_idy = align_flag
                         ? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * out_img_idy);
    in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
    T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
    src_h = (src_h > 0) ? src_h : 0;
    T h1lambda =
        align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
    T h2lambda = 1.f - h1lambda;

    // Width axis.
    int out_img_idx = tid % out_img_w;
    int in_img_idx = align_flag
                         ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * out_img_idx);
    in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
    T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
    src_w = (src_w > 0) ? src_w : 0;
    T w1lambda =
        align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
    T w2lambda = 1.f - w1lambda;

    int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
                      (in_img_idt * in_img_h + in_img_idy) * in_img_w +
                      in_img_idx;
    T* in_pos1 = &in[in_pos1_idx];
    int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
    T* in_pos2 = &in[in_pos2_idx];

    const T* out_pos = &out[out_id_h * output_w + out_id_w];

    // trilinear interpolation grad
    platform::CudaAtomicAdd(&in_pos1[0],
                            d2lambda * h2lambda * w2lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos1[w_id],
                            d2lambda * h2lambda * w1lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w],
                            d2lambda * h1lambda * w2lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w + w_id],
                            d2lambda * h1lambda * w1lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos2[0],
                            d1lambda * h2lambda * w2lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos2[w_id],
                            d1lambda * h2lambda * w1lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w],
                            d1lambda * h1lambda * w2lambda * out_pos[0]);
    platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w + w_id],
                            d1lambda * h1lambda * w1lambda * out_pos[0]);
  }
}

// Forward 2-D interpolation ("nearest" or "bilinear") of a 4-D input whose
// dims are read as [n, c, in_h, in_w].
//
// The output spatial size is taken from the out_h/out_w attributes. It is
// overridden by `scale` when scale > 0, and then by the optional OutSize input
// (copied to CPU and read as two ints). An unchanged size is a plain copy.
// An empty output (zero elements) is allocated and left as-is. Any other
// interp_method value launches nothing.
template <typename T>
static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
                                 const Tensor& input, Tensor* output) {
  // Fixed 512-thread blocks, at most 8 blocks; the kernels' grid-stride
  // loops cover any remaining elements.
  constexpr int kThreadsPerBlock = 512;
  constexpr int kMaxGridDim = 8;

  auto* input_data = input.data<T>();

  const int n = input.dims()[0];
  const int c = input.dims()[1];
  const int in_h = input.dims()[2];
  const int in_w = input.dims()[3];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    Tensor sizes;
    framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
    auto size_data = sizes.data<int>();
    out_h = size_data[0];
    out_w = size_data[1];
  }

  auto output_data =
      output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());

  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Input-per-output step; left at 0 for a single-element output axis.
  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  int in_hw = in_h * in_w;
  int out_hw = out_h * out_w;
  int in_chw = c * in_hw;
  int out_chw = c * out_hw;

  int pixelNum = n * out_chw;
  // Nothing to compute (empty batch or an output side of 0). A launch with a
  // zero-sized grid is an invalid configuration, so skip it.
  if (pixelNum == 0) {
    return;
  }
  int grid_dim = (pixelNum + kThreadsPerBlock - 1) / kThreadsPerBlock;
  grid_dim = grid_dim > kMaxGridDim ? kMaxGridDim : grid_dim;

  // Kernels see the tensors as [n, c*h*w] matrices.
  if ("nearest" == interp_method) {
    KeNearestNeighborInterpFw<T><<<grid_dim, kThreadsPerBlock, 0,
                                   ctx.cuda_device_context().stream()>>>(
        input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
        out_chw, c, ratio_h, ratio_w, align_corners);
  } else if ("bilinear" == interp_method) {
    KeBilinearInterpFw<T><<<grid_dim, kThreadsPerBlock, 0,
                            ctx.cuda_device_context().stream()>>>(
        input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
        out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
  }
}

// Forward 3-D ("trilinear") interpolation of a 5-D input whose dims are read
// as [n, c, in_d, in_h, in_w].
//
// The output size is taken from the out_d/out_h/out_w attributes. It is
// overridden by `scale` when scale > 0, and then by the optional OutSize input
// (copied to CPU and read as three ints). An unchanged size is a plain copy.
// An empty output is allocated and left as-is. Any other interp_method value
// launches nothing.
template <typename T>
static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
                                 const Tensor& input, Tensor* output) {
  // Fixed 512-thread blocks, at most 8 blocks; the kernel's grid-stride
  // loop covers any remaining elements.
  constexpr int kThreadsPerBlock = 512;
  constexpr int kMaxGridDim = 8;

  auto* input_data = input.data<T>();

  const int n = input.dims()[0];
  const int c = input.dims()[1];
  const int in_d = input.dims()[2];
  const int in_h = input.dims()[3];
  const int in_w = input.dims()[4];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_d = static_cast<int>(in_d * scale);
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    Tensor sizes;
    framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
    auto size_data = sizes.data<int>();
    out_d = size_data[0];
    out_h = size_data[1];
    out_w = size_data[2];
  }

  auto output_data =
      output->mutable_data<T>({n, c, out_d, out_h, out_w}, ctx.GetPlace());

  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Input-per-output step; left at 0 for a single-element output axis.
  float ratio_d = 0.f;
  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_d > 1) {
    ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
                              : static_cast<float>(in_d) / out_d;
  }
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  int in_dhw = in_d * in_h * in_w;
  int out_dhw = out_d * out_h * out_w;
  int in_cdhw = c * in_dhw;
  int out_cdhw = c * out_dhw;

  int pixelNum = n * out_cdhw;
  // Nothing to compute; a zero-sized grid would be an invalid launch.
  if (pixelNum == 0) {
    return;
  }
  int grid_dim = (pixelNum + kThreadsPerBlock - 1) / kThreadsPerBlock;
  grid_dim = grid_dim > kMaxGridDim ? kMaxGridDim : grid_dim;

  // The kernel sees the tensors as [n, c*d*h*w] matrices.
  if ("trilinear" == interp_method) {
    KeTrilinearInterpFw<T><<<grid_dim, kThreadsPerBlock, 0,
                             ctx.cuda_device_context().stream()>>>(
        input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
        out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
        align_mode);
  }
}

// Backward 2-D interpolation ("nearest" or "bilinear").
//
// Reads the input shape [n, c, in_h, in_w] from X and recomputes the output
// size exactly as Interpolate2DCUDAFwd does (out_h/out_w, then scale, then
// OutSize). If the size is unchanged, the gradient is copied through.
// Otherwise input_grad is zeroed and the scatter kernels accumulate into it.
// output_grad is taken by const reference, like Interpolate3DCUDABwd.
template <typename T>
static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
                                 Tensor* input_grad,
                                 const Tensor& output_grad) {
  // Fixed 512-thread blocks, at most 8 blocks; the kernels' grid-stride
  // loops cover any remaining elements.
  constexpr int kThreadsPerBlock = 512;
  constexpr int kMaxGridDim = 8;

  auto* input = ctx.Input<Tensor>("X");
  const int n = input->dims()[0];
  const int c = input->dims()[1];
  const int in_h = input->dims()[2];
  const int in_w = input->dims()[3];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    Tensor sizes;
    framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
    auto size_data = sizes.data<int>();
    out_h = size_data[0];
    out_w = size_data[1];
  }

  auto* output_grad_data = output_grad.data<T>();
  auto* input_grad_data =
      input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());

  // Identity resize: the copy overwrites every element, so no zero fill.
  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // The kernels only accumulate (atomic adds), so start from zero.
  auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  math::SetConstant<platform::CUDADeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  int in_hw = in_h * in_w;
  int out_hw = out_h * out_w;
  int in_chw = c * in_hw;
  int out_chw = c * out_hw;

  int pixelNum = n * out_chw;
  // No output gradient to scatter; input_grad stays zero. Skipping also
  // avoids an invalid zero-sized grid launch.
  if (pixelNum == 0) {
    return;
  }
  int grid_dim = (pixelNum + kThreadsPerBlock - 1) / kThreadsPerBlock;
  grid_dim = grid_dim > kMaxGridDim ? kMaxGridDim : grid_dim;

  if ("nearest" == interp_method) {
    KeNearestNeighborInterpBw<T><<<grid_dim, kThreadsPerBlock, 0,
                                   ctx.cuda_device_context().stream()>>>(
        input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
        n, out_chw, c, ratio_h, ratio_w, align_corners);
  } else if ("bilinear" == interp_method) {
    KeBilinearInterpBw<T><<<grid_dim, kThreadsPerBlock, 0,
                            ctx.cuda_device_context().stream()>>>(
        input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
        n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
  }
}

// Backward 3-D ("trilinear") interpolation.
//
// Reads the input shape [n, c, in_d, in_h, in_w] from X and recomputes the
// output size as Interpolate3DCUDAFwd does (out_* attributes, then scale,
// then OutSize). If the size is unchanged, the gradient is copied through.
// Otherwise input_grad is zeroed and the scatter kernel accumulates into it.
template <typename T>
static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
                                 Tensor* input_grad,
                                 const Tensor& output_grad) {
  // Fixed 512-thread blocks, at most 8 blocks; the kernel's grid-stride
  // loop covers any remaining elements.
  constexpr int kThreadsPerBlock = 512;
  constexpr int kMaxGridDim = 8;

  auto* input = ctx.Input<Tensor>("X");
  const int n = input->dims()[0];
  const int c = input->dims()[1];
  const int in_d = input->dims()[2];
  const int in_h = input->dims()[3];
  const int in_w = input->dims()[4];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");
  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_d = static_cast<int>(in_d * scale);
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    Tensor sizes;
    framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
    auto size_data = sizes.data<int>();
    out_d = size_data[0];
    out_h = size_data[1];
    out_w = size_data[2];
  }

  auto* output_grad_data = output_grad.data<T>();
  auto* input_grad_data =
      input_grad->mutable_data<T>({n, c, in_d, in_h, in_w}, ctx.GetPlace());

  // Identity resize: the copy overwrites every element, so no zero fill.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // The kernel only accumulates (atomic adds), so start from zero.
  auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  math::SetConstant<platform::CUDADeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  float ratio_d = 0.f;
  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_d > 1) {
    ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
                              : static_cast<float>(in_d) / out_d;
  }
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  int in_dhw = in_d * in_h * in_w;
  int out_dhw = out_d * out_h * out_w;
  int in_cdhw = c * in_dhw;
  int out_cdhw = c * out_dhw;

  int pixelNum = n * out_cdhw;
  // No output gradient to scatter; input_grad stays zero. Skipping also
  // avoids an invalid zero-sized grid launch.
  if (pixelNum == 0) {
    return;
  }
  int grid_dim = (pixelNum + kThreadsPerBlock - 1) / kThreadsPerBlock;
  grid_dim = grid_dim > kMaxGridDim ? kMaxGridDim : grid_dim;

  if ("trilinear" == interp_method) {
    KeTrilinearInterpBw<T><<<grid_dim, kThreadsPerBlock, 0,
                             ctx.cuda_device_context().stream()>>>(
        input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
        out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
        align_mode);
  }
}

// GPU forward kernel shared by bilinear_interp, nearest_interp and
// trilinear_interp. It dispatches on the rank of input X:
// 4 -> Interpolate2DCUDAFwd, 5 -> Interpolate3DCUDAFwd.
// NOTE(review): other ranks are not handled here and leave Out untouched;
// presumably InferShape rejects them — confirm.
template <typename T>
class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");

    auto input_dims = input->dims();
    if (input_dims.size() == 4) {  // 2D interpolation
      Interpolate2DCUDAFwd<T>(ctx, *input, output);
    } else if (input_dims.size() == 5) {  // 3D interpolation
      Interpolate3DCUDAFwd<T>(ctx, *input, output);
    }
  }
};

// GPU backward kernel shared by the *_interp_grad ops. It dispatches on the
// rank of the Out gradient: 4 -> Interpolate2DCUDABwd,
// 5 -> Interpolate3DCUDABwd.
// NOTE(review): other ranks leave the X gradient untouched; presumably they
// are rejected earlier — confirm.
template <typename T>
class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto output_grad_dims = output_grad->dims();
    if (output_grad_dims.size() == 4) {  // 2D interpolation
      Interpolate2DCUDABwd<T>(ctx, input_grad, *output_grad);
    } else if (output_grad_dims.size() == 5) {  // 3D interpolation
      Interpolate3DCUDABwd<T>(ctx, input_grad, *output_grad);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// All three ops share the same kernel classes, which dispatch on tensor rank
// and the interp_method attribute. Forward kernels are registered for float,
// double and int; gradient kernels for float and double only.
REGISTER_OP_CUDA_KERNEL(bilinear_interp, ops::InterpolateOpCUDAKernel<float>,
                        ops::InterpolateOpCUDAKernel<double>,
                        ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
                        ops::InterpolateGradOpCUDAKernel<float>,
                        ops::InterpolateGradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(nearest_interp, ops::InterpolateOpCUDAKernel<float>,
                        ops::InterpolateOpCUDAKernel<double>,
                        ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(nearest_interp_grad,
                        ops::InterpolateGradOpCUDAKernel<float>,
                        ops::InterpolateGradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel<float>,
                        ops::InterpolateOpCUDAKernel<double>,
                        ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad,
                        ops::InterpolateGradOpCUDAKernel<float>,
                        ops::InterpolateGradOpCUDAKernel<double>);