pad3d_op.cu 32.2 KB
Newer Older
L
littletomatodonkey 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
17 18
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
19
#include "paddle/phi/kernels/funcs/math_function.h"
L
littletomatodonkey 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 
540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;

using framework::Tensor;

// Constant-mode forward padding for NCDHW tensors.
// Each loop iteration writes one element of out_data, addressed by a flat
// index that is decoded as (plane, d, h, w) with w varying fastest; `plane`
// is the combined batch/channel index. Positions whose shifted coordinates
// fall outside the input box receive `value`; all others copy the matching
// input element. `num` and `channels` are not read.
template <typename T>
__global__ void Pad3DConstNCDHW(const int nthreads, const T* in_data,
                                const int num, const int channels,
                                const int in_depth, const int in_height,
                                const int in_width, const int out_depth,
                                const int out_height, const int out_width,
                                const int pad_front, const int pad_top,
                                const int pad_left, T value, T* out_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Peel the spatial coordinates off the flat output index, innermost first.
    int rest = index;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int plane = rest / out_depth;

    // Shift into input coordinates (may be negative or past the end).
    const int id = od - pad_front;
    const int ih = oh - pad_top;
    const int iw = ow - pad_left;

    const bool inside = id >= 0 && id < in_depth && ih >= 0 &&
                        ih < in_height && iw >= 0 && iw < in_width;
    if (inside) {
      out_data[index] =
          in_data[((plane * in_depth + id) * in_height + ih) * in_width + iw];
    } else {
      out_data[index] = value;
    }
  }
}

// Constant-mode forward padding for channels-last (NDHWC) tensors.
// Each flat output index is decoded as (batch, d, h, w, c) with c varying
// fastest. Spatial positions outside the input box receive `value`; the rest
// copy the input element at the same channel. `num` is not read.
template <typename T>
__global__ void Pad3DConstNDHWC(const int nthreads, const T* in_data,
                                const int num, const int channels,
                                const int in_depth, const int in_height,
                                const int in_width, const int out_depth,
                                const int out_height, const int out_width,
                                const int pad_front, const int pad_top,
                                const int pad_left, T value, T* out_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    int rest = index;
    const int ch = rest % channels;
    rest /= channels;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int batch = rest / out_depth;

    const int id = od - pad_front;
    const int ih = oh - pad_top;
    const int iw = ow - pad_left;

    const bool outside = id < 0 || id >= in_depth || ih < 0 ||
                         ih >= in_height || iw < 0 || iw >= in_width;
    T result = value;
    if (!outside) {
      const int src =
          (((batch * in_depth + id) * in_height + ih) * in_width + iw) *
              channels +
          ch;
      result = in_data[src];
    }
    out_data[index] = result;
  }
}

// Reflect-mode forward padding for NCDHW tensors.
// Shifted coordinates are mirrored once about index 0 and once about the last
// valid index, without repeating the edge element. A single reflection only
// stays in range when each pad is smaller than the matching input extent,
// which the host launcher in this file enforces. `num` and `channels` are
// not read.
template <typename T>
__global__ void Pad3DReflectNCDHW(const int nthreads, const T* in_data,
                                  const int num, const int channels,
                                  const int in_depth, const int in_height,
                                  const int in_width, const int out_depth,
                                  const int out_height, const int out_width,
                                  const int pad_front, const int pad_top,
                                  const int pad_left, T* out_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    int rest = index;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int plane = rest / out_depth;

    int id = od - pad_front;
    int ih = oh - pad_top;
    int iw = ow - pad_left;

    // Mirror about index 0.
    id = id < 0 ? -id : id;
    ih = ih < 0 ? -ih : ih;
    iw = iw < 0 ? -iw : iw;
    // Mirror about the last valid index.
    id = id >= in_depth ? 2 * in_depth - id - 2 : id;
    ih = ih >= in_height ? 2 * in_height - ih - 2 : ih;
    iw = iw >= in_width ? 2 * in_width - iw - 2 : iw;

    const int src = ((plane * in_depth + id) * in_height + ih) * in_width + iw;
    out_data[index] = in_data[src];
  }
}

// Reflect-mode forward padding for channels-last (NDHWC) tensors.
// Each flat output index is decoded as (batch, d, h, w, c) with c fastest.
// The shifted spatial coordinates are mirrored back into the input without
// repeating the edge element, and the element at the same channel is copied.
// `num` is not read.
template <typename T>
__global__ void Pad3DReflectNDHWC(const int nthreads, const T* in_data,
                                  const int num, const int channels,
                                  const int in_depth, const int in_height,
                                  const int in_width, const int out_depth,
                                  const int out_height, const int out_width,
                                  const int pad_front, const int pad_top,
                                  const int pad_left, T* out_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    int rest = index;
    const int ch = rest % channels;
    rest /= channels;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int batch = rest / out_depth;

    int id = od - pad_front;
    int ih = oh - pad_top;
    int iw = ow - pad_left;

    if (id < 0) id = -id;
    if (id >= in_depth) id = 2 * in_depth - id - 2;
    if (ih < 0) ih = -ih;
    if (ih >= in_height) ih = 2 * in_height - ih - 2;
    if (iw < 0) iw = -iw;
    if (iw >= in_width) iw = 2 * in_width - iw - 2;

    const int src =
        (((batch * in_depth + id) * in_height + ih) * in_width + iw) *
            channels +
        ch;
    out_data[index] = in_data[src];
  }
}

// Replicate-mode forward padding for NCDHW tensors.
// Each shifted coordinate is clamped first to >= 0 and then to <= extent - 1,
// so padded positions repeat the nearest edge element. `num` and `channels`
// are not read.
template <typename T>
__global__ void Pad3DReplicateNCDHW(const int nthreads, const T* in_data,
                                    const int num, const int channels,
                                    const int in_depth, const int in_height,
                                    const int in_width, const int out_depth,
                                    const int out_height, const int out_width,
                                    const int pad_front, const int pad_top,
                                    const int pad_left, T* out_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    int rest = index;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int plane = rest / out_depth;

    int id = od - pad_front;
    if (id < 0) id = 0;
    if (id > in_depth - 1) id = in_depth - 1;

    int ih = oh - pad_top;
    if (ih < 0) ih = 0;
    if (ih > in_height - 1) ih = in_height - 1;

    int iw = ow - pad_left;
    if (iw < 0) iw = 0;
    if (iw > in_width - 1) iw = in_width - 1;

    const int src = ((plane * in_depth + id) * in_height + ih) * in_width + iw;
    out_data[index] = in_data[src];
  }
}

// Replicate-mode forward padding for channels-last (NDHWC) tensors.
// Each flat output index is decoded as (batch, d, h, w, c) with c fastest.
// Spatial coordinates are clamped to [0, extent - 1] (lower bound first), and
// the element at the same channel is copied. `num` is not read.
template <typename T>
__global__ void Pad3DReplicateNDHWC(const int nthreads, const T* in_data,
                                    const int num, const int channels,
                                    const int in_depth, const int in_height,
                                    const int in_width, const int out_depth,
                                    const int out_height, const int out_width,
                                    const int pad_front, const int pad_top,
                                    const int pad_left, T* out_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    int rest = index;
    const int ch = rest % channels;
    rest /= channels;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int batch = rest / out_depth;

    int id = od - pad_front;
    if (id < 0) id = 0;
    if (id > in_depth - 1) id = in_depth - 1;

    int ih = oh - pad_top;
    if (ih < 0) ih = 0;
    if (ih > in_height - 1) ih = in_height - 1;

    int iw = ow - pad_left;
    if (iw < 0) iw = 0;
    if (iw > in_width - 1) iw = in_width - 1;

    const int src =
        (((batch * in_depth + id) * in_height + ih) * in_width + iw) *
            channels +
        ch;
    out_data[index] = in_data[src];
  }
}

// Circular-mode forward padding for NCDHW tensors.
// Each shifted coordinate wraps around the input extent, and a negative
// remainder from C++ `%` is folded back into [0, extent). `num` and
// `channels` are not read.
template <typename T>
__global__ void Pad3DCircularNCDHW(const int nthreads, const T* in_data,
                                   const int num, const int channels,
                                   const int in_depth, const int in_height,
                                   const int in_width, const int out_depth,
                                   const int out_height, const int out_width,
                                   const int pad_front, const int pad_top,
                                   const int pad_left, T* out_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    int rest = index;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int plane = rest / out_depth;

    int id = (od - pad_front) % in_depth;
    if (id < 0) id += in_depth;
    int ih = (oh - pad_top) % in_height;
    if (ih < 0) ih += in_height;
    int iw = (ow - pad_left) % in_width;
    if (iw < 0) iw += in_width;

    const int src = ((plane * in_depth + id) * in_height + ih) * in_width + iw;
    out_data[index] = in_data[src];
  }
}

// Circular-mode forward padding for channels-last (NDHWC) tensors.
// Each flat output index is decoded as (batch, d, h, w, c) with c fastest.
// Spatial coordinates wrap around the input extent, and negative remainders
// are folded back into [0, extent). `num` is not read.
template <typename T>
__global__ void Pad3DCircularNDHWC(const int nthreads, const T* in_data,
                                   const int num, const int channels,
                                   const int in_depth, const int in_height,
                                   const int in_width, const int out_depth,
                                   const int out_height, const int out_width,
                                   const int pad_front, const int pad_top,
                                   const int pad_left, T* out_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    int rest = index;
    const int ch = rest % channels;
    rest /= channels;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int batch = rest / out_depth;

    int id = (od - pad_front) % in_depth;
    if (id < 0) id += in_depth;
    int ih = (oh - pad_top) % in_height;
    if (ih < 0) ih += in_height;
    int iw = (ow - pad_left) % in_width;
    if (iw < 0) iw += in_width;

    const int src =
        (((batch * in_depth + id) * in_height + ih) * in_width + iw) *
            channels +
        ch;
    out_data[index] = in_data[src];
  }
}

// Backward of constant-mode padding for NCDHW tensors.
// Iterates over the flat index of d_in_data (in_size elements). Each input
// position gathers the gradient from its shifted location inside d_out_data;
// gradient on padded output positions is never read. `num` and `channels`
// are not read.
template <typename T>
__global__ void Pad3DGradConstNCDHW(const int in_size, T* d_in_data,
                                    const int num, const int channels,
                                    const int in_depth, const int in_height,
                                    const int in_width, const int out_depth,
                                    const int out_height, const int out_width,
                                    const int pad_front, const int pad_top,
                                    const int pad_left, const T* d_out_data) {
  CUDA_KERNEL_LOOP(in_index, in_size) {
    int rest = in_index;
    const int iw = rest % in_width;
    rest /= in_width;
    const int ih = rest % in_height;
    rest /= in_height;
    const int id = rest % in_depth;
    const int plane = rest / in_depth;

    const int od = id + pad_front;
    const int oh = ih + pad_top;
    const int ow = iw + pad_left;

    const int src =
        ((plane * out_depth + od) * out_height + oh) * out_width + ow;
    d_in_data[in_index] = d_out_data[src];
  }
}

// Backward of constant-mode padding for channels-last (NDHWC) tensors.
// Iterates over d_in_data, decodes each flat index as (batch, d, h, w, c)
// with c fastest, and copies the gradient from the shifted interior position
// of d_out_data at the same channel. `num` is not read.
template <typename T>
__global__ void Pad3DGradConstNDHWC(const int in_size, T* d_in_data,
                                    const int num, const int channels,
                                    const int in_depth, const int in_height,
                                    const int in_width, const int out_depth,
                                    const int out_height, const int out_width,
                                    const int pad_front, const int pad_top,
                                    const int pad_left, const T* d_out_data) {
  CUDA_KERNEL_LOOP(in_index, in_size) {
    int rest = in_index;
    const int ch = rest % channels;
    rest /= channels;
    const int iw = rest % in_width;
    rest /= in_width;
    const int ih = rest % in_height;
    rest /= in_height;
    const int id = rest % in_depth;
    const int batch = rest / in_depth;

    const int od = id + pad_front;
    const int oh = ih + pad_top;
    const int ow = iw + pad_left;

    const int src =
        (((batch * out_depth + od) * out_height + oh) * out_width + ow) *
            channels +
        ch;
    d_in_data[in_index] = d_out_data[src];
  }
}

// Backward of reflect-mode padding for NCDHW tensors.
// Iterates over every element of d_out_data and maps it to the input element
// that produced it in the forward pass (mirror about 0, then about the last
// index). Several output positions can map to the same input element, so
// their gradients are accumulated with platform::CudaAtomicAdd. The values
// are added onto whatever d_in_data already holds. `num` and `channels` are
// not read.
template <typename T>
__global__ void Pad3DGradReflectNCDHW(const int out_size, T* d_in_data,
                                      const int num, const int channels,
                                      const int in_depth, const int in_height,
                                      const int in_width, const int out_depth,
                                      const int out_height, const int out_width,
                                      const int pad_front, const int pad_top,
                                      const int pad_left, const T* d_out_data) {
  CUDA_KERNEL_LOOP(out_index, out_size) {
    int rest = out_index;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int plane = rest / out_depth;

    int id = od - pad_front;
    int ih = oh - pad_top;
    int iw = ow - pad_left;

    if (id < 0) id = -id;
    if (ih < 0) ih = -ih;
    if (iw < 0) iw = -iw;

    if (id >= in_depth) id = 2 * in_depth - id - 2;
    if (ih >= in_height) ih = 2 * in_height - ih - 2;
    if (iw >= in_width) iw = 2 * in_width - iw - 2;

    T* dst = d_in_data +
             (((plane * in_depth + id) * in_height + ih) * in_width + iw);
    platform::CudaAtomicAdd(dst, d_out_data[out_index]);
  }
}

// Backward of reflect-mode padding for channels-last (NDHWC) tensors.
// Each output-gradient element (decoded as batch, d, h, w, c with c fastest)
// is atomically added onto the input element it was mirrored from in the
// forward pass, at the same channel. `num` is not read.
template <typename T>
__global__ void Pad3DGradReflectNDHWC(const int out_size, T* d_in_data,
                                      const int num, const int channels,
                                      const int in_depth, const int in_height,
                                      const int in_width, const int out_depth,
                                      const int out_height, const int out_width,
                                      const int pad_front, const int pad_top,
                                      const int pad_left, const T* d_out_data) {
  CUDA_KERNEL_LOOP(out_index, out_size) {
    int rest = out_index;
    const int ch = rest % channels;
    rest /= channels;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int batch = rest / out_depth;

    int id = od - pad_front;
    int ih = oh - pad_top;
    int iw = ow - pad_left;

    id = id < 0 ? -id : id;
    ih = ih < 0 ? -ih : ih;
    iw = iw < 0 ? -iw : iw;

    id = id >= in_depth ? in_depth * 2 - id - 2 : id;
    ih = ih >= in_height ? in_height * 2 - ih - 2 : ih;
    iw = iw >= in_width ? in_width * 2 - iw - 2 : iw;

    const int dst =
        (((batch * in_depth + id) * in_height + ih) * in_width + iw) *
            channels +
        ch;
    platform::CudaAtomicAdd(d_in_data + dst, d_out_data[out_index]);
  }
}

// Backward of replicate-mode padding for NCDHW tensors.
// Each output-gradient element is atomically accumulated into the input
// element it was copied from in the forward pass. That element is found by
// clamping each coordinate to >= 0 and then to <= extent - 1, so edge
// elements receive the gradient of all of their replicas. `num` and
// `channels` are not read.
template <typename T>
__global__ void Pad3DGradReplicateNCDHW(
    const int out_size, T* d_in_data, const int num, const int channels,
    const int in_depth, const int in_height, const int in_width,
    const int out_depth, const int out_height, const int out_width,
    const int pad_front, const int pad_top, const int pad_left,
    const T* d_out_data) {
  CUDA_KERNEL_LOOP(out_index, out_size) {
    int rest = out_index;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int plane = rest / out_depth;

    int id = od - pad_front;
    if (id < 0) id = 0;
    if (id > in_depth - 1) id = in_depth - 1;

    int ih = oh - pad_top;
    if (ih < 0) ih = 0;
    if (ih > in_height - 1) ih = in_height - 1;

    int iw = ow - pad_left;
    if (iw < 0) iw = 0;
    if (iw > in_width - 1) iw = in_width - 1;

    T* dst = d_in_data +
             (((plane * in_depth + id) * in_height + ih) * in_width + iw);
    platform::CudaAtomicAdd(dst, d_out_data[out_index]);
  }
}

// Backward of replicate-mode padding for channels-last (NDHWC) tensors.
// Each output-gradient element (decoded as batch, d, h, w, c with c fastest)
// is atomically added onto the clamped input position at the same channel.
// `num` is not read.
template <typename T>
__global__ void Pad3DGradReplicateNDHWC(
    const int out_size, T* d_in_data, const int num, const int channels,
    const int in_depth, const int in_height, const int in_width,
    const int out_depth, const int out_height, const int out_width,
    const int pad_front, const int pad_top, const int pad_left,
    const T* d_out_data) {
  CUDA_KERNEL_LOOP(out_index, out_size) {
    int rest = out_index;
    const int ch = rest % channels;
    rest /= channels;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int batch = rest / out_depth;

    int id = od - pad_front;
    if (id < 0) id = 0;
    if (id > in_depth - 1) id = in_depth - 1;

    int ih = oh - pad_top;
    if (ih < 0) ih = 0;
    if (ih > in_height - 1) ih = in_height - 1;

    int iw = ow - pad_left;
    if (iw < 0) iw = 0;
    if (iw > in_width - 1) iw = in_width - 1;

    const int dst =
        (((batch * in_depth + id) * in_height + ih) * in_width + iw) *
            channels +
        ch;
    platform::CudaAtomicAdd(d_in_data + dst, d_out_data[out_index]);
  }
}

// Backward of circular-mode padding for NCDHW tensors.
// Each output-gradient element is atomically accumulated into the input
// element it wrapped around to in the forward pass. Negative remainders from
// `%` are folded back into [0, extent). `num` and `channels` are not read.
template <typename T>
__global__ void Pad3DGradCircularNCDHW(const int out_size, T* d_in_data,
                                       const int num, const int channels,
                                       const int in_depth, const int in_height,
                                       const int in_width, const int out_depth,
                                       const int out_height,
                                       const int out_width, const int pad_front,
                                       const int pad_top, const int pad_left,
                                       const T* d_out_data) {
  CUDA_KERNEL_LOOP(out_index, out_size) {
    int rest = out_index;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int plane = rest / out_depth;

    int id = (od - pad_front) % in_depth;
    if (id < 0) id += in_depth;
    int ih = (oh - pad_top) % in_height;
    if (ih < 0) ih += in_height;
    int iw = (ow - pad_left) % in_width;
    if (iw < 0) iw += in_width;

    T* dst = d_in_data +
             (((plane * in_depth + id) * in_height + ih) * in_width + iw);
    platform::CudaAtomicAdd(dst, d_out_data[out_index]);
  }
}

// Backward of circular-mode padding for channels-last (NDHWC) tensors.
// Each output-gradient element (decoded as batch, d, h, w, c with c fastest)
// is atomically added onto the wrapped input position at the same channel.
// `num` is not read.
template <typename T>
__global__ void Pad3DGradCircularNDHWC(const int out_size, T* d_in_data,
                                       const int num, const int channels,
                                       const int in_depth, const int in_height,
                                       const int in_width, const int out_depth,
                                       const int out_height,
                                       const int out_width, const int pad_front,
                                       const int pad_top, const int pad_left,
                                       const T* d_out_data) {
  CUDA_KERNEL_LOOP(out_index, out_size) {
    int rest = out_index;
    const int ch = rest % channels;
    rest /= channels;
    const int ow = rest % out_width;
    rest /= out_width;
    const int oh = rest % out_height;
    rest /= out_height;
    const int od = rest % out_depth;
    const int batch = rest / out_depth;

    int id = (od - pad_front) % in_depth;
    if (id < 0) id += in_depth;
    int ih = (oh - pad_top) % in_height;
    if (ih < 0) ih += in_height;
    int iw = (ow - pad_left) % in_width;
    if (iw < 0) iw += in_width;

    const int dst =
        (((batch * in_depth + id) * in_height + ih) * in_width + iw) *
            channels +
        ch;
    platform::CudaAtomicAdd(d_in_data + dst, d_out_data[out_index]);
  }
}

// Returns the six pad3d padding amounts, ordered
// [left, right, top, bottom, front, back] as consumed by the kernels in this
// file. The optional "Paddings" input tensor (copied to the CPU and read as
// int) takes precedence over the "paddings" attribute. An attribute shorter
// than six entries leaves the remaining pads at zero.
//
// Raises InvalidArgument if the Paddings tensor holds fewer than six values
// (the raw memcpy would otherwise over-read it), or if the attribute holds
// more than six (std::copy would otherwise overflow the result buffer).
static inline std::vector<int> GetPaddings(
    const framework::ExecutionContext& context) {
  std::vector<int> paddings(6);
  const int64_t expected = static_cast<int64_t>(paddings.size());
  auto* paddings_data = context.Input<Tensor>("Paddings");
  if (paddings_data) {
    Tensor pads;
    framework::TensorCopySync(*paddings_data, platform::CPUPlace(), &pads);
    PADDLE_ENFORCE_GE(
        pads.numel(), expected,
        platform::errors::InvalidArgument(
            "The size of Input(Paddings) should be at least %d, "
            "but received %d.",
            expected, pads.numel()));
    auto pads_data = pads.data<int>();
    std::memcpy(paddings.data(), pads_data, paddings.size() * sizeof(int));
  } else {
    auto pads = context.Attr<std::vector<int>>("paddings");
    PADDLE_ENFORCE_LE(
        static_cast<int64_t>(pads.size()), expected,
        platform::errors::InvalidArgument(
            "The size of Attr(paddings) should be at most %d, "
            "but received %d.",
            expected, static_cast<int64_t>(pads.size())));
    std::copy(pads.begin(), pads.end(), paddings.data());
  }
  return paddings;
}

// Forward CUDA kernel for the pad3d operator.
//
// Reads the paddings (Input "Paddings" or attribute "paddings", ordered
// [left, right, top, bottom, front, back]), the "mode" attribute
// ("reflect", "replicate", "circular", anything else = constant fill with
// attribute "value") and "data_format" ("NCDHW", anything else is handled as
// NDHWC). It then resizes Out and launches the matching element-wise kernel
// on the context's CUDA stream.
template <typename T>
class Pad3dCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    std::vector<int> pads = GetPaddings(context);
    auto mode = context.Attr<std::string>("mode");
    auto data_format = context.Attr<std::string>("data_format");
    T value = static_cast<T>(context.Attr<float>("value"));

    auto* x = context.Input<Tensor>("X");
    auto in_dims = x->dims();
    const T* in_data = x->data<T>();
    auto* out = context.Output<Tensor>("Out");
    auto out_dims = out->dims();
    if (data_format == "NCDHW") {
      out_dims[0] = in_dims[0];
      out_dims[1] = in_dims[1];
      out_dims[2] = in_dims[2] + pads[4] + pads[5];
      out_dims[3] = in_dims[3] + pads[2] + pads[3];
      out_dims[4] = in_dims[4] + pads[0] + pads[1];
    } else {
      out_dims[0] = in_dims[0];
      out_dims[1] = in_dims[1] + pads[4] + pads[5];
      out_dims[2] = in_dims[2] + pads[2] + pads[3];
      out_dims[3] = in_dims[3] + pads[0] + pads[1];
      out_dims[4] = in_dims[4];
    }
    T* out_data = out->mutable_data<T>(out_dims, context.GetPlace());

    int channels = in_dims[1];
    int in_depth = in_dims[2];
    int in_height = in_dims[3];
    int in_width = in_dims[4];
    int out_depth = out_dims[2];
    int out_height = out_dims[3];
    int out_width = out_dims[4];
    if (data_format == "NDHWC") {
      channels = in_dims[4];
      in_depth = in_dims[1];
      in_height = in_dims[2];
      in_width = in_dims[3];
      out_depth = out_dims[1];
      out_height = out_dims[2];
      out_width = out_dims[3];
    }

    if (mode == "reflect") {
      // The reflect kernels mirror only once, so every pad must be strictly
      // smaller than the extent it pads.
      PADDLE_ENFORCE_GT(in_depth, pads[4],
                        platform::errors::InvalidArgument(
                            "The depth of Input(X)'s dimension should be "
                            "greater than pad_front"
                            " in reflect mode"
                            ", but received depth(%d) and pad_front(%d).",
                            in_depth, pads[4]));
      PADDLE_ENFORCE_GT(in_depth, pads[5],
                        platform::errors::InvalidArgument(
                            "The depth of Input(X)'s dimension should be "
                            "greater than pad_back"
                            " in reflect mode"
                            ", but received depth(%d) and pad_back(%d).",
                            in_depth, pads[5]));

      PADDLE_ENFORCE_GT(in_height, pads[2],
                        platform::errors::InvalidArgument(
                            "The height of Input(X)'s dimension should be "
                            "greater than pad_top"
                            " in reflect mode"
                            ", but received height(%d) and pad_top(%d).",
                            in_height, pads[2]));
      PADDLE_ENFORCE_GT(in_height, pads[3],
                        platform::errors::InvalidArgument(
                            "The height of Input(X)'s dimension should be "
                            "greater than pad_bottom"
                            " in reflect mode"
                            ", but received height(%d) and pad_bottom(%d).",
                            in_height, pads[3]));

      PADDLE_ENFORCE_GT(in_width, pads[0],
                        platform::errors::InvalidArgument(
                            "The width of Input(X)'s dimension should be "
                            "greater than pad_left"
                            " in reflect mode"
                            ", but received width(%d) and pad_left(%d).",
                            in_width, pads[0]));
      PADDLE_ENFORCE_GT(in_width, pads[1],
                        platform::errors::InvalidArgument(
                            "The width of Input(X)'s dimension should be "
                            "greater than pad_right"
                            " in reflect mode"
                            ", but received width(%d) and pad_right(%d).",
                            in_width, pads[1]));
    } else if (mode == "circular" || mode == "replicate") {
      // Both modes always read some input element, so an empty spatial
      // volume would cause out-of-range reads (and `%` by zero for circular).
      PADDLE_ENFORCE_NE(in_depth * in_height * in_width, 0,
                        platform::errors::InvalidArgument(
                            "The input tensor size can not be 0 for circular "
                            "or replicate padding mode."));
    }

    const int pad_left = pads[0];
    const int pad_top = pads[2];
    const int pad_front = pads[4];
    const int num = in_dims[0];

    // NOTE(review): numel() is narrowed to int here, and all kernels index in
    // int, so tensors with more than INT_MAX elements are not supported.
    const int out_size = out->numel();
    if (out_size == 0) {
      // Nothing to write. A grid of 0 blocks would also be rejected by CUDA
      // as an invalid launch configuration.
      return;
    }

    auto stream = context.cuda_device_context().stream();
    int block = PADDLE_CUDA_NUM_THREADS;
    int grid = (out_size + block - 1) / block;

    if (data_format == "NCDHW") {
      if (mode == "reflect") {
        Pad3DReflectNCDHW<T><<<grid, block, 0, stream>>>(
            out_size, in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            out_data);
      } else if (mode == "replicate") {
        Pad3DReplicateNCDHW<T><<<grid, block, 0, stream>>>(
            out_size, in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            out_data);
      } else if (mode == "circular") {
        Pad3DCircularNCDHW<T><<<grid, block, 0, stream>>>(
            out_size, in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            out_data);
      } else {
        Pad3DConstNCDHW<T><<<grid, block, 0, stream>>>(
            out_size, in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            value, out_data);
      }
    } else {
      if (mode == "reflect") {
        Pad3DReflectNDHWC<T><<<grid, block, 0, stream>>>(
            out_size, in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            out_data);
      } else if (mode == "replicate") {
        Pad3DReplicateNDHWC<T><<<grid, block, 0, stream>>>(
            out_size, in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            out_data);
      } else if (mode == "circular") {
        Pad3DCircularNDHWC<T><<<grid, block, 0, stream>>>(
            out_size, in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            out_data);
      } else {
        Pad3DConstNDHWC<T><<<grid, block, 0, stream>>>(
            out_size, in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            value, out_data);
      }
    }
  }
};

template <typename T>
// Backward CUDA kernel for pad3d: computes d(X) from d(Out).
//
// d(X) is first zero-filled with SetConstant; one of the Pad3DGrad* kernels
// (defined elsewhere in this file) is then launched on the device stream:
//   * "reflect" / "replicate" / "circular": launched over d(Out)'s element
//     count (out_size);
//   * any other mode value (constant padding): launched over d(X)'s element
//     count (in_size).
// Layout is NCDHW when data_format == "NCDHW"; every other value is treated
// as NDHWC (C at axis 4, D/H/W at axes 1/2/3).
class Pad3dGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    std::vector<int> pads = GetPaddings(context);
    auto mode = context.Attr<std::string>("mode");
    auto data_format = context.Attr<std::string>("data_format");
    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_in = context.Output<Tensor>(framework::GradVarName("X"));
    auto d_in_dims = d_in->dims();
    auto d_out_dims = d_out->dims();
    const T* d_out_data = d_out->data<T>();
    T* d_in_data = d_in->mutable_data<T>(context.GetPlace());

    // Start from an all-zero input gradient.
    phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_zero;
    set_zero(context.template device_context<platform::CUDADeviceContext>(),
             d_in, static_cast<T>(0));

    // NOTE(review): numel() is narrowed to int here; presumably the
    // Pad3DGrad* kernels index with int, so tensors with more than INT_MAX
    // elements are not supported -- confirm against the kernel signatures.
    const int out_size = d_out->numel();
    const int in_size = d_in->numel();

    // With an empty d(Out) no input element receives gradient, and with an
    // empty d(X) there is nothing to write. In both cases the zero fill
    // above is already the result. A <<<0, block>>> launch would fail with
    // an invalid-configuration error, so stop here.
    if (out_size == 0 || in_size == 0) {
      return;
    }

    // pads is read as {left, right, top, bottom, front, back}; only the
    // leading pad of each spatial axis is passed to the kernels.
    const int pad_left = pads[0];
    const int pad_top = pads[2];
    const int pad_front = pads[4];

    // Axis positions: NCDHW -> C=1, D/H/W=2/3/4; otherwise C=4, D/H/W=1/2/3.
    const bool is_ncdhw = (data_format == "NCDHW");
    const int channel_axis = is_ncdhw ? 1 : 4;
    const int depth_axis = is_ncdhw ? 2 : 1;

    const int num = d_in_dims[0];
    const int channels = d_in_dims[channel_axis];
    const int in_depth = d_in_dims[depth_axis];
    const int in_height = d_in_dims[depth_axis + 1];
    const int in_width = d_in_dims[depth_axis + 2];
    const int out_depth = d_out_dims[depth_axis];
    const int out_height = d_out_dims[depth_axis + 1];
    const int out_width = d_out_dims[depth_axis + 2];

    auto stream = context.cuda_device_context().stream();
    const int block = PADDLE_CUDA_NUM_THREADS;
    // Reflect/replicate/circular kernels are sized by d(Out); the constant
    // kernel is sized by d(X).
    const int out_grid = (out_size + block - 1) / block;
    const int in_grid = (in_size + block - 1) / block;

    if (is_ncdhw) {
      if (mode == "reflect") {
        Pad3DGradReflectNCDHW<T><<<out_grid, block, 0, stream>>>(
            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            d_out_data);
      } else if (mode == "replicate") {
        Pad3DGradReplicateNCDHW<T><<<out_grid, block, 0, stream>>>(
            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            d_out_data);
      } else if (mode == "circular") {
        Pad3DGradCircularNCDHW<T><<<out_grid, block, 0, stream>>>(
            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            d_out_data);
      } else {
        Pad3DGradConstNCDHW<T><<<in_grid, block, 0, stream>>>(
            in_size, d_in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            d_out_data);
      }
    } else {
      if (mode == "reflect") {
        Pad3DGradReflectNDHWC<T><<<out_grid, block, 0, stream>>>(
            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            d_out_data);
      } else if (mode == "replicate") {
        Pad3DGradReplicateNDHWC<T><<<out_grid, block, 0, stream>>>(
            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            d_out_data);
      } else if (mode == "circular") {
        Pad3DGradCircularNDHWC<T><<<out_grid, block, 0, stream>>>(
            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            d_out_data);
      } else {
        Pad3DGradConstNDHWC<T><<<in_grid, block, 0, stream>>>(
            in_size, d_in_data, num, channels, in_depth, in_height, in_width,
            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
            d_out_data);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward op "pad3d": registered for float16/float/double and int/int64_t.
REGISTER_OP_CUDA_KERNEL(pad3d,
                        ops::Pad3dCUDAKernel<plat::float16>,
                        ops::Pad3dCUDAKernel<float>,
                        ops::Pad3dCUDAKernel<double>,
                        ops::Pad3dCUDAKernel<int>,
                        ops::Pad3dCUDAKernel<int64_t>);

// Backward op "pad3d_grad": registered for floating-point types only.
REGISTER_OP_CUDA_KERNEL(pad3d_grad,
                        ops::Pad3dGradCUDAKernel<plat::float16>,
                        ops::Pad3dGradCUDAKernel<float>,
                        ops::Pad3dGradCUDAKernel<double>);