conv_grad_grad_kernel.cu 27.8 KB
Newer Older
H
hong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/conv_grad_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/fluid/framework/eigen.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/conv_miopen_helper.h"
#else
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#endif

#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/padding.h"

#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"

#include "paddle/phi/kernels/impl/conv_cudnn_impl.h"

#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
void ConvCudnnGradGradKernel(
    const Context& ctx,
    paddle::optional<const DenseTensor&> input_grad_grad,
    paddle::optional<const DenseTensor&> filter_grad_grad,
    const DenseTensor& out_grad,
    const DenseTensor& input,
    const DenseTensor& filter,
    const std::vector<int>& strides,
    const std::vector<int>& paddings_t,
    const std::string& padding_algorithm,
    int groups,
    const std::vector<int>& dilations_t,
    const std::string& data_format,
    bool use_addto,
    int workspace_size_MB,
    bool exhaustive_search_t,
    DenseTensor* out_grad_grad,
    DenseTensor* input_grad,
    DenseTensor* filter_grad) {
  auto X = &input;
  auto W = &filter;
  auto dO = &out_grad;
  auto ddX = input_grad_grad.get_ptr();
  auto ddW = filter_grad_grad.get_ptr();

  auto ddO = out_grad_grad;
  auto dW = filter_grad;
  auto dX = input_grad;
  if (ddO) {
    ddO->mutable_data<T>(ctx.GetPlace());
    phi::funcs::SetConstant<Context, T> set_zero;
    set_zero(ctx, ddO, static_cast<T>(0));
  }
  if (dW) {
    dW->mutable_data<T>(ctx.GetPlace());
  }
  if (dX) {
    dX->mutable_data<T>(ctx.GetPlace());
  }

  // const T* x = X->data<T>();
  const T* dy = dO->data<T>();
  const T* w = W->data<T>();

  const T* ddx = nullptr;
  const T* ddw = nullptr;
  T *dw, *dx, *ddy;
  dw = dx = ddy = nullptr;
  T* transformed_dx = nullptr;
  std::vector<int> dilations = dilations_t;

  bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_t;
  bool deterministic = FLAGS_cudnn_deterministic;
  auto exhaustive_deterministic = exhaustive_search && deterministic;
  PADDLE_ENFORCE_EQ(exhaustive_deterministic,
                    false,
                    phi::errors::InvalidArgument(
                        "Cann't set exhaustive_search True and "
                        "FLAGS_cudnn_deterministic True at same time."));

  std::vector<int> paddings = paddings_t;

  const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

  // transform Tensors to channel first-----------
  DenseTensor transformed_X_channel(X->type());
  DenseTensor transformed_dO_channel(dO->type());
  DenseTensor transformed_ddX_channel(X->type());

  DenseTensor transformed_ddO_channel(dO->type());
  DenseTensor transformed_dX_channel(X->type());

  if (channel_last) {
    ResizeToChannelFirst<Context, T>(ctx, X, &transformed_X_channel);
    TransToChannelFirst<Context, T>(ctx, X, &transformed_X_channel);

    ResizeToChannelFirst<Context, T>(ctx, dO, &transformed_dO_channel);
    TransToChannelFirst<Context, T>(ctx, dO, &transformed_dO_channel);

    if (ddX) {
      ResizeToChannelFirst<Context, T>(ctx, ddX, &transformed_ddX_channel);
      TransToChannelFirst<Context, T>(ctx, ddX, &transformed_ddX_channel);
    }

    if (ddO) {
      ResizeToChannelFirst<Context, T>(ctx, ddO, &transformed_ddO_channel);
    }
    if (dX) {
      ResizeToChannelFirst<Context, T>(ctx, dX, &transformed_dX_channel);
      transformed_dX_channel.mutable_data<T>(ctx.GetPlace());
    }

  } else {
    transformed_X_channel = *X;
    transformed_dO_channel = *dO;
    if (ddX) {
      transformed_ddX_channel = *ddX;
    }
    if (ddO) {
      transformed_ddO_channel.ShareDataWith(*ddO);
    }
    if (dX) {
      transformed_dX_channel.ShareDataWith(*dX);
    }
  }

  auto in_dims = transformed_X_channel.dims();
  auto filter_dims = W->dims();
  DDim in_data_dims = slice_ddim(in_dims, 2, in_dims.size());
  DDim filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size());
  std::vector<int> ksize = vectorize<int>(filter_data_dims);
  UpdatePaddingAndDilation(
      &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);

  int data_dim = strides.size();  // 2d or 3d
  bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim);
  DenseTensor transformed_X(X->type());
  DenseTensor transformed_ddX(X->type());

  DenseTensor transformed_dX(X->type());

  std::vector<int> padding_common(data_dim, 0);
  std::vector<int> input_pad(X->dims().size() * 2, 0);

  if (!is_sys_pad) {
    // get pad
    std::vector<int> padding_diff(data_dim);
    std::vector<int> new_input_shape_vec(data_dim + 2);
    new_input_shape_vec[0] = transformed_X_channel.dims()[0];
    new_input_shape_vec[1] = transformed_X_channel.dims()[1];

    for (size_t i = 0; i < data_dim; ++i) {
      padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
      padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
      new_input_shape_vec[i + 2] =
          transformed_X_channel.dims()[i + 2] + padding_diff[i];
      input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
      input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
    }
    DDim new_input_shape(make_ddim(new_input_shape_vec));
    transformed_X.Resize(new_input_shape);
    transformed_ddX.Resize(new_input_shape);
    transformed_dX.Resize(new_input_shape);

    transformed_X.mutable_data<T>(ctx.GetPlace());

    if (ddX) {
      transformed_ddX.mutable_data<T>(ctx.GetPlace());
    }
    if (dX) {
      transformed_dX.mutable_data<T>(ctx.GetPlace());
    }

    // pad for input
    const int rank = X->dims().size();
    T pad_value(0.0);
    switch (rank) {
      case 4: {
        funcs::PadFunction<Context, T, 4>(
            ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
        if (ddX) {
          funcs::PadFunction<Context, T, 4>(ctx,
                                            input_pad,
                                            transformed_ddX_channel,
                                            pad_value,
                                            &transformed_ddX);
        }
      } break;
      case 5: {
        funcs::PadFunction<Context, T, 5>(
            ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
        if (ddX) {
          funcs::PadFunction<Context, T, 5>(ctx,
                                            input_pad,
                                            transformed_ddX_channel,
                                            pad_value,
                                            &transformed_ddX);
        }
      } break;
      default:
        PADDLE_THROW(phi::errors::InvalidArgument(
            "ConvOp only support tensors with 4 or 5 dimensions."));
    }

  } else {
    transformed_X.ShareDataWith(transformed_X_channel);
    if (ddX) {
      transformed_ddX.ShareDataWith(transformed_ddX_channel);
    }
    if (dX) {
      transformed_dX.ShareDataWith(transformed_dX_channel);
    }

    if (paddings.size() == data_dim) {
      for (size_t i = 0; i < data_dim; ++i) {
        padding_common[i] = paddings[i];
      }
    } else {
      for (size_t i = 0; i < data_dim; ++i) {
        padding_common[i] = paddings[2 * i];
      }
    }
  }

  const T* x = transformed_X.data<T>();

  int iwo_group = groups;
  int c_group = 1;
#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1)
  iwo_group = 1;
  c_group = groups;
  groups = 1;
#endif
  auto dtype = paddle::platform::CudnnDataType<T>::type;

  auto handle = ctx.cudnn_handle();

  paddle::operators::ConvArgs args1{&transformed_ddX,
                                    W,
                                    &transformed_ddO_channel,
                                    strides,
                                    padding_common,
                                    dilations,
                                    dtype};
  paddle::operators::ConvArgs args2{&transformed_X,
                                    ddW,
                                    &transformed_ddO_channel,
                                    strides,
                                    padding_common,
                                    dilations,
                                    dtype};
  paddle::operators::ConvArgs args3{&transformed_ddX,
                                    dW,
                                    &transformed_dO_channel,
                                    strides,
                                    padding_common,
                                    dilations,
                                    dtype};
  paddle::operators::ConvArgs args4{&transformed_dX,
                                    ddW,
                                    &transformed_dO_channel,
                                    strides,
                                    padding_common,
                                    dilations,
                                    dtype};

#ifdef PADDLE_WITH_HIP
  miopenConvFwdAlgorithm_t fwd_algo1 = static_cast<miopenConvFwdAlgorithm_t>(0);
  miopenConvFwdAlgorithm_t fwd_algo2 = static_cast<miopenConvFwdAlgorithm_t>(0);
  miopenConvBwdDataAlgorithm_t data_algo =
      static_cast<miopenConvBwdDataAlgorithm_t>(0);
  miopenConvBwdWeightsAlgorithm_t filter_algo =
      static_cast<miopenConvBwdWeightsAlgorithm_t>(0);
#else
  cudnnConvolutionFwdAlgo_t fwd_algo1 =
      static_cast<cudnnConvolutionFwdAlgo_t>(0);
  cudnnConvolutionFwdAlgo_t fwd_algo2 =
      static_cast<cudnnConvolutionFwdAlgo_t>(0);
  cudnnConvolutionBwdDataAlgo_t data_algo =
      static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
  cudnnConvolutionBwdFilterAlgo_t filter_algo =
      static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
#endif

  auto layout = paddle::platform::GetCudnnTensorFormat(
      paddle::platform::DataLayout::kNCHW);

  // ddo = conv(ddI, W) + conv(I, ddW)
  size_t workspace_size = 0;

  T* transformed_ddy_channel = nullptr;
  if (ddO) {
    ddy = ddO->data<T>();
    transformed_ddy_channel = transformed_ddO_channel.data<T>();
    if (ddX) {
      args1.handle = handle;
      args1.idesc.set(transformed_ddX, iwo_group);
      args1.wdesc.set(*W, layout, iwo_group);
      args1.odesc.set(transformed_ddO_channel, iwo_group);
      args1.cdesc.set(dtype,
                      padding_common,
                      strides,
                      dilations,
                      paddle::platform::AllowTF32Cudnn(),
                      c_group);

#ifdef PADDLE_WITH_HIP
      using search1 =
          paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
      workspace_size = search1::GetWorkspaceSize(args1);
      fwd_algo1 = search1::Find<T>(
          args1, exhaustive_search, false, workspace_size, ctx);
#else
      using search1 =
          paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
      fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
      workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
#endif
    }

    if (ddW) {
      ddw = ddW->data<T>();
      args2.handle = handle;
      args2.idesc.set(transformed_X, iwo_group);
      args2.wdesc.set(*ddW, layout, iwo_group);
      args2.odesc.set(transformed_ddO_channel, iwo_group);
      args2.cdesc.set(dtype,
                      padding_common,
                      strides,
                      dilations,
                      paddle::platform::AllowTF32Cudnn(),
                      c_group);

#ifdef PADDLE_WITH_HIP
      using search2 =
          paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
      workspace_size =
          std::max(workspace_size, search2::GetWorkspaceSize(args2));
      fwd_algo2 = search2::Find<T>(
          args2, exhaustive_search, false, workspace_size, ctx);
#else
      using search2 =
          paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
      fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
      workspace_size =
          std::max(workspace_size, search2::GetWorkspaceSize(args2, fwd_algo2));
#endif
    }
  }

  if (dW && ddX) {
    dw = dW->data<T>();
    args3.handle = handle;
    args3.idesc.set(transformed_ddX, iwo_group);
    args3.wdesc.set(*dW, layout, iwo_group);
    args3.odesc.set(transformed_dO_channel, iwo_group);
    args3.cdesc.set(dtype,
                    padding_common,
                    strides,
                    dilations,
                    paddle::platform::AllowTF32Cudnn(),
                    c_group);

#ifdef PADDLE_WITH_HIP
    using search3 =
        paddle::operators::SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
    workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3));
    filter_algo = search3::Find<T>(
        args3, exhaustive_search, deterministic, workspace_size, ctx);
#else
    using search3 =
        paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
    filter_algo =
        search3::Find<T>(args3, exhaustive_search, deterministic, ctx);
    workspace_size =
        std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo));
#endif
  }

  if (ddW && dX) {
    transformed_dx = transformed_dX.data<T>();

    args4.handle = handle;
    args4.idesc.set(transformed_dX, iwo_group);
    args4.wdesc.set(*ddW, layout, iwo_group);
    args4.odesc.set(transformed_dO_channel, iwo_group);
    args4.cdesc.set(dtype,
                    padding_common,
                    strides,
                    dilations,
                    paddle::platform::AllowTF32Cudnn(),
                    c_group);

#ifdef PADDLE_WITH_HIP
    using search4 =
        paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
    workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4));
    data_algo = search4::Find<T>(
        args4, exhaustive_search, deterministic, workspace_size, ctx);
#else
    using search4 =
        paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
    data_algo = search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
    workspace_size =
        std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
#endif
  }

  int i_n, i_c, i_d, i_h, i_w;
  GetNCDHW(
      transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);

  int o_n, o_c, o_d, o_h, o_w;
  GetNCDHW(transformed_dO_channel.dims(),
           DataLayout::kNCHW,
           &o_n,
           &o_c,
           &o_d,
           &o_h,
           &o_w);

  int group_offset_in = i_c / groups * i_h * i_w * i_d;
  int group_offset_out = o_c / groups * o_h * o_w * o_d;
  int group_offset_filter = W->numel() / groups;

  paddle::operators::ScalingParamType<T> alpha = 1.0f;
  paddle::operators::ScalingParamType<T> beta = 0.0f;

  // NOTE(zhiqiu): inplace addto is not supportted in double grad yet.
  // ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f :
  // 0.0f;
  // VLOG(4) << "Conv_grad_grad: use_addto = " << ctx.Attr<bool>("use_addto");
  auto wkspace_handle = ctx.cudnn_workspace_handle();

  if (ddO) {
    if (ddX) {
      ddx = transformed_ddX.data<T>();
#ifdef PADDLE_WITH_HIP
      wkspace_handle.RunFunc(
          [&](void* workspace_ptr) {
            PADDLE_ENFORCE_GPU_SUCCESS(
                paddle::platform::dynload::miopenConvolutionForward(
                    handle,
                    &alpha,
                    args1.idesc.desc(),
                    ddx,
                    args1.wdesc.desc(),
                    w,
                    args1.cdesc.desc(),
                    fwd_algo1,
                    &beta,
                    args1.odesc.desc(),
                    transformed_ddy_channel,
                    workspace_ptr,
                    workspace_size));
          },
          workspace_size);
#else
      for (int i = 0; i < groups; i++) {
        wkspace_handle.RunFunc(
            [&](void* workspace_ptr) {
              PADDLE_ENFORCE_GPU_SUCCESS(
                  paddle::platform::dynload::cudnnConvolutionForward(
                      handle,
                      &alpha,
                      args1.idesc.desc(),
                      ddx + i * group_offset_in,
                      args1.wdesc.desc(),
                      w + i * group_offset_filter,
                      args1.cdesc.desc(),
                      fwd_algo1,
                      workspace_ptr,
                      workspace_size,
                      &beta,
                      args1.odesc.desc(),
                      transformed_ddy_channel + i * group_offset_out));
            },
            workspace_size);
      }
#endif
    }
    if (ddW) {
#ifdef PADDLE_WITH_HIP
      // MIOPEN ONLY support beta to be 0.0f
      wkspace_handle.RunFunc(
          [&](void* workspace_ptr) {
            PADDLE_ENFORCE_GPU_SUCCESS(
                paddle::platform::dynload::miopenConvolutionForward(
                    handle,
                    &alpha,
                    args2.idesc.desc(),
                    x,
                    args2.wdesc.desc(),
                    ddw,
                    args2.cdesc.desc(),
                    fwd_algo2,
                    &beta,
                    args2.odesc.desc(),
                    transformed_ddy_channel,
                    workspace_ptr,
                    workspace_size));
          },
          workspace_size);
#else
      for (int i = 0; i < groups; i++) {
        wkspace_handle.RunFunc(
            [&](void* workspace_ptr) {
              PADDLE_ENFORCE_GPU_SUCCESS(
                  paddle::platform::dynload::cudnnConvolutionForward(
                      handle,
                      &alpha,
                      args2.idesc.desc(),
                      x + i * group_offset_in,
                      args2.wdesc.desc(),
                      ddw + i * group_offset_filter,
                      args2.cdesc.desc(),
                      fwd_algo2,
                      workspace_ptr,
                      workspace_size,
                      &alpha,
                      args2.odesc.desc(),
                      transformed_ddy_channel + i * group_offset_out));
            },
            workspace_size);
      }
#endif
    }
    if (channel_last) {
      TransToChannelLast<Context, T>(ctx, &transformed_ddO_channel, ddO);
    }
  }
  T* transformed_dy_channel = transformed_dO_channel.data<T>();
  if (dW && ddX) {
    ddx = transformed_ddX.data<T>();
#ifdef PADDLE_WITH_HIP
    wkspace_handle.RunFunc(
        [&](void* workspace_ptr) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              paddle::platform::dynload::miopenConvolutionBackwardWeights(
                  handle,
                  &alpha,
                  args3.odesc.desc(),
                  transformed_dy_channel,
                  args3.idesc.desc(),
                  ddx,
                  args3.cdesc.desc(),
                  filter_algo,
                  &beta,
                  args3.wdesc.desc(),
                  dw,
                  workspace_ptr,
                  workspace_size));
        },
        workspace_size);
#else
    for (int i = 0; i < groups; i++) {
      wkspace_handle.RunFunc(
          [&](void* workspace_ptr) {
            PADDLE_ENFORCE_GPU_SUCCESS(
                paddle::platform::dynload::cudnnConvolutionBackwardFilter(
                    handle,
                    &alpha,
                    args3.idesc.desc(),
                    ddx + i * group_offset_in,
                    args3.odesc.desc(),
                    transformed_dy_channel + i * group_offset_out,
                    args3.cdesc.desc(),
                    filter_algo,
                    workspace_ptr,
                    workspace_size,
                    &beta,
                    args3.wdesc.desc(),
                    dw + i * group_offset_filter));
          },
          workspace_size);
    }
#endif
  }

  if (dX && ddW) {
    ddw = ddW->data<T>();
#ifdef PADDLE_WITH_HIP
    wkspace_handle.RunFunc(
        [&](void* workspace_ptr) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              paddle::platform::dynload::miopenConvolutionBackwardData(
                  handle,
                  &alpha,
                  args4.odesc.desc(),
                  transformed_dy_channel,
                  args4.wdesc.desc(),
                  ddw,
                  args4.cdesc.desc(),
                  data_algo,
                  &beta,
                  args4.idesc.desc(),
                  transformed_dx,
                  workspace_ptr,
                  workspace_size));
        },
        workspace_size);
#else
    for (int i = 0; i < groups; i++) {
      wkspace_handle.RunFunc(
          [&](void* workspace_ptr) {
            PADDLE_ENFORCE_GPU_SUCCESS(
                paddle::platform::dynload::cudnnConvolutionBackwardData(
                    handle,
                    &alpha,
                    args4.wdesc.desc(),
                    ddw + i * group_offset_filter,
                    args4.odesc.desc(),
                    transformed_dy_channel + i * group_offset_out,
                    args4.cdesc.desc(),
                    data_algo,
                    workspace_ptr,
                    workspace_size,
                    &beta,
                    args4.idesc.desc(),
                    transformed_dx + i * group_offset_in));
          },
          workspace_size);
    }
#endif

    if (!is_sys_pad) {
      // reverse padded input
      std::vector<int> starts(X->dims().size(), 0);
      std::vector<int> axes(X->dims().size(), 0);

      for (size_t i = 0; i < X->dims().size(); ++i) {
        starts[i] = input_pad[2 * i];
        axes[i] = i;
      }
      if (X->dims().size() == 4) {
        paddle::operators::RemovePaddingSlice<Context, T, 4>(
            ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
      } else {
        paddle::operators::RemovePaddingSlice<Context, T, 5>(
            ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
      }
    }
    if (channel_last) {
      TransToChannelLast<Context, T>(ctx, &transformed_dX_channel, dX);
    }
  }
}

template <typename T, typename Context>
void DepthwiseConvCudnnGradGradKernel(
    const Context& ctx,
    paddle::optional<const DenseTensor&> input_grad_grad,
    paddle::optional<const DenseTensor&> filter_grad_grad,
    const DenseTensor& out_grad,
    const DenseTensor& input,
    const DenseTensor& filter,
    const std::vector<int>& strides,
    const std::vector<int>& paddings_t,
    const std::string& padding_algorithm,
    int groups,
    const std::vector<int>& dilations_t,
    const std::string& data_format,
    bool use_addto,
    int workspace_size_MB,
    bool exhaustive_search_t,
    bool fuse_relu,
    DenseTensor* out_grad_grad,
    DenseTensor* input_grad,
    DenseTensor* filter_grad) {
  ConvCudnnGradGradKernel<T>(ctx,
                             input_grad_grad,
                             filter_grad_grad,
                             out_grad,
                             input,
                             filter,
                             strides,
                             paddings_t,
                             padding_algorithm,
                             groups,
                             dilations_t,
                             data_format,
                             use_addto,
                             workspace_size_MB,
                             exhaustive_search_t,
                             out_grad_grad,
                             input_grad,
                             filter_grad);
}

template <typename T, typename Context>
void Conv3DCudnnGradGradKernel(
    const Context& ctx,
    paddle::optional<const DenseTensor&> input_grad_grad,
    paddle::optional<const DenseTensor&> filter_grad_grad,
    const DenseTensor& out_grad,
    const DenseTensor& input,
    const DenseTensor& filter,
    const std::vector<int>& strides,
    const std::vector<int>& paddings_t,
    const std::string& padding_algorithm,
    int groups,
    const std::vector<int>& dilations_t,
    const std::string& data_format,
    bool use_addto,
    int workspace_size_MB,
    bool exhaustive_search_t,
    DenseTensor* out_grad_grad,
    DenseTensor* input_grad,
    DenseTensor* filter_grad) {
  ConvCudnnGradGradKernel<T>(ctx,
                             input_grad_grad,
                             filter_grad_grad,
                             out_grad,
                             input,
                             filter,
                             strides,
                             paddings_t,
                             padding_algorithm,
                             groups,
                             dilations_t,
                             data_format,
                             use_addto,
                             workspace_size_MB,
                             exhaustive_search_t,
                             out_grad_grad,
                             input_grad,
                             filter_grad);
}

}  // namespace phi

// Kernel registration.
//
// Three kernels are registered in every configuration: conv2d_grad_grad and
// conv3d_grad_grad under the GPUDNN backend, and depthwise_conv2d_grad_grad
// under the GPU backend. Only the list of element types differs per build:
//   - HIP builds:                float, float16
//   - CUDA, cuDNN >= 8.1.0:      float, double, float16, bfloat16
//   - CUDA, older cuDNN:         float, double, float16
// NOTE(review): depthwise_conv2d_grad_grad uses GPU rather than GPUDNN even
// though it forwards to the cuDNN implementation -- confirm this is intended.
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(conv2d_grad_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::ConvCudnnGradGradKernel,
                   float,
                   phi::dtype::float16) {}

PD_REGISTER_KERNEL(conv3d_grad_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Conv3DCudnnGradGradKernel,
                   float,
                   phi::dtype::float16) {}

PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::DepthwiseConvCudnnGradGradKernel,
                   float,
                   phi::dtype::float16) {}
#else
// bfloat16 is registered only when building against cuDNN 8.1.0 or newer.
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(conv2d_grad_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::ConvCudnnGradGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(conv3d_grad_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Conv3DCudnnGradGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::DepthwiseConvCudnnGradGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

#else

// Same registrations as above, without bfloat16.
PD_REGISTER_KERNEL(conv2d_grad_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::ConvCudnnGradGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}

PD_REGISTER_KERNEL(conv3d_grad_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Conv3DCudnnGradGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}

PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::DepthwiseConvCudnnGradGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}

#endif  // CUDNN_VERSION_MIN(8, 1, 0)

#endif  // PADDLE_WITH_HIP