// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/instance_norm_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/phi/kernels/gpu/instance_norm_utils.h"

namespace phi {
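// First-order gradient w.r.t. x. One block handles one (n, c) slice of
// sample_size = H * W * D elements; the "variance" argument is used as the
// saved inverse standard deviation (inv_var_val). Per element:
//   dx = scale * inv_std * (dy - mean(dy)
//        - (x - mean) * inv_std^2 * mean(dy * (x - mean)))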
template <typename T, int BlockDim>
static __global__ void GradComputeDX(const T *dy,
                                     const BatchNormParamType<T> *scale,
                                     const BatchNormParamType<T> *mean,
                                     const T *x,
                                     const BatchNormParamType<T> *variance,
                                     const int C,
                                     const int sample_size,
                                     T *dx) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;
  BatchNormParamType<T> mean_val = mean[ncid];
  BatchNormParamType<T> inv_var_val = variance[ncid];
  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
  __shared__ BatchNormParamType<T> dy_sum_val;
  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;
  BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
  BatchNormParamType<T> dy_x_sub_mean_sum =
      static_cast<BatchNormParamType<T>>(0);

  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    BatchNormParamType<T> dy_i = static_cast<BatchNormParamType<T>>(dy[i]);
    dy_sum += dy_i;
    dy_x_sub_mean_sum +=
        dy_i * (static_cast<BatchNormParamType<T>>(x[i]) - mean_val);
  }
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  dy_x_sub_mean_sum =
      BlockReduce(dy_x_sub_mean_storage).Reduce(dy_x_sub_mean_sum, cub::Sum());
  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
  }
  __syncthreads();
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    dx[i] = static_cast<T>(
        (static_cast<BatchNormParamType<T>>(dy[i]) -
         dy_sum_val / static_cast<BatchNormParamType<T>>(sample_size) -
         (static_cast<BatchNormParamType<T>>(x[i]) - mean_val) *
             dy_x_sub_mean_sum_val * inv_var_val * inv_var_val / sample_size) *
        scale[c] * inv_var_val);
  }
}

// Note: despite the name, these helpers return the reciprocal of the square
// root, i.e. 1 / sqrt(x).
static __device__ __forceinline__ float real_sqrt(float x) {
  return 1. / sqrtf(x);
}
static __device__ __forceinline__ double real_sqrt(double x) {
  return 1. / sqrt(x);
}

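// Second-order backward w.r.t. x: accumulates into dx the contribution coming
// from ddx (scaled by scale) and, when present, the contribution coming from
// ddscale. "variance" holds the saved inverse standard deviation and one block
// handles one (n, c) slice.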
template <typename T, typename AccT, int BlockDim>
__global__ void DoubleGradComputeDX(const T *x,
                                    const AccT *mean,
                                    const AccT *variance,
                                    const T *ddx,
                                    const T *dy,
                                    const AccT *scale,
                                    const AccT *ddscale,
                                    int C,
                                    int sample_size,
                                    const double epsilon,
                                    T *dx) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  AccT mean_val = mean[ncid];
  AccT var_val = variance[ncid];

  typedef cub::BlockReduce<AccT, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ AccT dy_sum_val;
  __shared__ AccT ddx_sum_val;
  __shared__ AccT dy_mul_ddx_sum_val;
  __shared__ AccT dy_mul_x_sub_mean_sum_val;
  __shared__ AccT ddx_mul_x_sub_mean_sum_val;

  AccT dy_sum = 0;
  AccT ddx_sum = 0;
  AccT dy_mul_ddx_sum = 0;
  AccT dy_mul_x_sub_mean_sum = 0;
  AccT ddx_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    AccT ddx_i = static_cast<AccT>(ddx[i]);
    AccT dy_i = static_cast<AccT>(dy[i]);
    AccT tmp = static_cast<AccT>(x[i]) - mean_val;

    dy_sum += dy_i;
    ddx_sum += ddx_i;
    dy_mul_ddx_sum += (ddx_i * dy_i);

    dy_mul_x_sub_mean_sum += (dy_i * tmp);
    ddx_mul_x_sub_mean_sum += (ddx_i * tmp);
  }

  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  dy_mul_ddx_sum =
      BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum());
  dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                              .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    ddx_sum_val = ddx_sum;
    dy_mul_ddx_sum_val = dy_mul_ddx_sum;
    dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();

  if (ddx != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      AccT tmp = static_cast<AccT>(dx[i]);
      tmp +=
          ((static_cast<AccT>(x[i]) - mean_val) * var_val * var_val * var_val /
               sample_size *
               (ddx_sum_val * dy_sum_val / sample_size - dy_mul_ddx_sum_val +
                3. * dy_mul_x_sub_mean_sum_val * var_val *
                    ddx_mul_x_sub_mean_sum_val * var_val / sample_size) +
           ddx_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val * (dy_sum_val / sample_size - static_cast<AccT>(dy[i])) +
           dy_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val *
               (ddx_sum_val / sample_size - static_cast<AccT>(ddx[i]))) *
          scale[c];
      dx[i] = static_cast<T>(tmp);
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      AccT tmp = static_cast<AccT>(dx[i]);
      tmp += (static_cast<AccT>(dy[i]) * var_val -
              dy_sum_val / sample_size * var_val -
              (static_cast<AccT>(x[i]) - mean_val) * var_val *
                  dy_mul_x_sub_mean_sum_val * var_val / sample_size) *
             ddscale[c];
      dx[i] = static_cast<T>(tmp);
    }
  }
}

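// Second-order gradient of the forward output:
//   ddy += scale * inv_std * (ddx - mean(ddx)
//          - (x - mean) * inv_std^2 * mean(ddx * (x - mean)))
//          + (x - mean) * inv_std * ddscale + ddbias
// with one block handling one (n, c) slice; each term is only added when the
// corresponding input is non-null.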
template <typename T, typename AccT, int BlockDim>
__global__ void DoubleGradComputeDDY(const T *x,
                                     const AccT *mean,
                                     const AccT *variance,
                                     const AccT *ddscale,
                                     const AccT *ddbias,
                                     const T *ddx,
                                     const AccT *scale,
                                     int C,
                                     int sample_size,
                                     const double epsilon,
                                     T *ddy) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;
  AccT mean_val = mean[ncid];
  AccT var_val = variance[ncid];
  typedef cub::BlockReduce<AccT, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ AccT ddx_sum_val;
  __shared__ AccT ddx_mul_x_sub_mean_sum_val;

  AccT ddx_sum = 0;
  AccT ddx_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    AccT ddx_i = static_cast<AccT>(ddx[i]);
    ddx_sum += ddx_i;
    ddx_mul_x_sub_mean_sum += (ddx_i * (static_cast<AccT>(x[i]) - mean_val));
  }
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());
  if (threadIdx.x == 0) {
    ddx_sum_val = ddx_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();
  if (ddx != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      AccT tmp = static_cast<AccT>(ddy[i]);
      tmp += scale[c] * var_val *
             (static_cast<AccT>(ddx[i]) - ddx_sum_val / sample_size -
              (static_cast<AccT>(x[i]) - mean_val) * var_val *
                  ddx_mul_x_sub_mean_sum_val * var_val / sample_size);
      ddy[i] = static_cast<T>(tmp);
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      AccT tmp = static_cast<AccT>(ddy[i]);
      tmp += (static_cast<AccT>(x[i]) - mean_val) * var_val * ddscale[c];
      ddy[i] = static_cast<T>(tmp);
    }
  }
  __syncthreads();
  if (ddbias != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] = static_cast<T>(static_cast<AccT>(ddy[i]) + ddbias[c]);
    }
  }
}

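// Second-order gradient w.r.t. scale, reduced per (n, c) slice:
//   dscale[ncid] += sum_i ddx_i * inv_std * (dy_i - mean(dy)
//                   - (x_i - mean) * inv_std^2 * mean(dy * (x - mean)))
// The per-slice results are later summed over N on the host side via
// add_param.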
template <typename T, typename AccT, int BlockDim>
__global__ void DoubleGradComputeDScale(const T *x,
                                        const AccT *mean,
                                        const AccT *variance,
                                        const T *ddx,
                                        const T *dy,
                                        int C,
                                        int sample_size,
                                        const double epsilon,
                                        AccT *dscale) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;
  AccT mean_val = mean[ncid];
  AccT var_val = variance[ncid];
  typedef cub::BlockReduce<AccT, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage dscale_tmp_storage;
  __shared__ AccT dy_sum_val;
  __shared__ AccT dy_mul_x_sub_mean_sum_val;

  AccT dy_sum = 0;
  AccT dy_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    AccT dy_i = static_cast<AccT>(dy[i]);
    dy_sum += dy_i;
    dy_mul_x_sub_mean_sum += (dy_i * (static_cast<AccT>(x[i]) - mean_val));
  }
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                              .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());

  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
  }
  __syncthreads();
  if (ddx != nullptr) {
    AccT dscale_tmp = 0;
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dscale_tmp +=
          static_cast<AccT>(ddx[i]) * var_val *
          (static_cast<AccT>(dy[i]) - dy_sum_val / sample_size -
           dy_mul_x_sub_mean_sum_val * (static_cast<AccT>(x[i]) - mean_val) *
               var_val * var_val / sample_size);
    }
    dscale_tmp = BlockReduce(dscale_tmp_storage).Reduce(dscale_tmp, cub::Sum());
    if (threadIdx.x == 0) {
      dscale[ncid] += dscale_tmp;
    }
    __syncthreads();
  }
}

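// InstanceNormGrad: the (N, C, H, W[, D]) problem is viewed as a batch-norm
// problem of shape (1, N * C, H, W[, D]) so the cuDNN / MIOpen batch-norm
// backward can be reused; the per-(n, c) scale/bias gradients it produces are
// then reduced over N with add_param. When scale/bias gradients are not
// requested, the custom GradComputeDX kernel computes dx directly.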
template <typename T, typename Context>
void InstanceNormGradKernel(const Context &dev_ctx,
                            const DenseTensor &x,
                            const paddle::optional<DenseTensor> &scale,
                            const DenseTensor &saved_mean,
                            const DenseTensor &saved_variance,
                            const DenseTensor &d_y,
                            float epsilon_f,
                            DenseTensor *d_x,
                            DenseTensor *d_scale,
                            DenseTensor *d_bias) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  double epsilon = static_cast<double>(epsilon_f);
  const auto *scale_ptr = scale.get_ptr();

  const auto &x_dims = x.dims();

  int N, C, H, W, D;
  funcs::ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
  int NxC = N * C;

  DenseTensor x_tmp, d_y_tmp;
  x_tmp.ShareDataWith(x).Resize({1, NxC, H, W, D});
  d_y_tmp.ShareDataWith(d_y).Resize({1, NxC, H, W, D});

  dev_ctx.template Alloc<T>(d_x);
  if (d_scale && d_bias) {
    dev_ctx.template Alloc<AccT>(d_scale);
    dev_ctx.template Alloc<AccT>(d_bias);
  }
  if (scale_ptr) {
    PADDLE_ENFORCE_EQ(
        scale_ptr->dims().size(),
        1UL,
        phi::errors::InvalidArgument(
            "The `shape` in InstanceNormOp is invalid: "
            "the size of scale's dimensions must be equal to 1. But "
            "received: the size of scale's dimensions is [%d].",
            scale_ptr->dims().size()));
    PADDLE_ENFORCE_EQ(scale_ptr->dims()[0],
                      C,
                      phi::errors::InvalidArgument(
                          "The `shape` in InstanceNormOp is invalid: "
                          "the first dimension of scale must be equal to "
                          "Channels([%d]). But received: "
                          "the first dimension of scale is [%d], "
                          "the dimensions of scale are [%s].",
                          C,
                          scale_ptr->dims()[0],
                          scale_ptr->dims()));
  }

  phi::funcs::SetConstant<GPUContext, AccT> set_constant;

  const int n = x.numel();
  const int block = 512;
  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / block, 1);
  const int grid = std::min(NxC, max_blocks);
  const int grid1 = (C + block - 1) / block;

  DenseTensor scale_tmp;
  scale_tmp.Resize({NxC});
  dev_ctx.template Alloc<AccT>(&scale_tmp);

  DenseTensor d_scale_tmp;
  d_scale_tmp.Resize({NxC});
  dev_ctx.template Alloc<AccT>(&d_scale_tmp);

  DenseTensor d_bias_tmp;
  d_bias_tmp.Resize({NxC});
  dev_ctx.template Alloc<AccT>(&d_bias_tmp);

  if (scale_ptr) {
    repeat_param<AccT><<<grid, block, 0, dev_ctx.stream()>>>(
        scale_ptr->data<AccT>(), scale_tmp.data<AccT>(), N, C);
  } else {
    set_constant(dev_ctx, &scale_tmp, static_cast<AccT>(1));
  }
  std::vector<int> dims;
  std::vector<int> strides;
  dims = {1, NxC, H, W, D};
  strides = {NxC * H * W * D, H * W * D, W * D, D, 1};

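  // Degenerate case: each (n, c) instance has a single element, so dx is just
  // a copy of dy and the scale/bias gradients are zero.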
  if ((H * W * D) == 1) {
    phi::Copy(dev_ctx, d_y, dev_ctx.GetPlace(), false, d_x);
    if (d_scale && d_bias) {
      phi::funcs::SetConstant<GPUContext, BatchNormParamType<T>> functor;
      functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
      functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
    }
    return;
  }

#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t data_desc_;
  miopenTensorDescriptor_t in_param_desc_;

  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::miopenCreateTensorDescriptor(&data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::miopenCreateTensorDescriptor(&in_param_desc_));
#else
  cudnnTensorDescriptor_t data_desc_;
  cudnnTensorDescriptor_t in_param_desc_;

  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::cudnnCreateTensorDescriptor(&data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::cudnnCreateTensorDescriptor(&in_param_desc_));
#endif

  if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
    LOG(ERROR) << "Provided epsilon is smaller than "
               << "CUDNN_BN_MIN_EPSILON. Setting it to "
               << "CUDNN_BN_MIN_EPSILON instead.";
  }
  epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);

#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenSetTensorDescriptor(
      data_desc_,
      CudnnDataType<T>::type,
      x_dims.size() > 3 ? x_dims.size() : 4,
      const_cast<int *>(dims.data()),
      const_cast<int *>(strides.data())));
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenDeriveBNTensorDescriptor(
      in_param_desc_, data_desc_, miopenBNSpatial));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor(
      data_desc_,
      CudnnDataType<T>::type,
      x_dims.size() > 3 ? x_dims.size() : 4,
      dims.data(),
      strides.data()));
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDeriveBNTensorDescriptor(
      in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL));
#endif
  const auto *saved_mean_data =
      saved_mean.template data<BatchNormParamType<T>>();
  const auto *saved_var_data =
      saved_variance.template data<BatchNormParamType<T>>();

  if (d_scale && d_bias) {
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenBatchNormalizationBackward(
        dev_ctx.cudnn_handle(),
        miopenBNSpatial,
        CudnnDataType<T>::kOne(),
        CudnnDataType<T>::kZero(),
        CudnnDataType<T>::kOne(),
        CudnnDataType<T>::kZero(),
        data_desc_,
        x_tmp.template data<T>(),
        data_desc_,
        d_y_tmp.template data<T>(),
        data_desc_,
        d_x->template data<T>(),
        in_param_desc_,
        scale_tmp.template data<BatchNormParamType<T>>(),
        d_scale_tmp.template data<BatchNormParamType<T>>(),
        d_bias_tmp.template data<BatchNormParamType<T>>(),
        epsilon,
        saved_mean_data,
        saved_var_data));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBatchNormalizationBackward(
        dev_ctx.cudnn_handle(),
        CUDNN_BATCHNORM_SPATIAL,
        CudnnDataType<T>::kOne(),
        CudnnDataType<T>::kZero(),
        CudnnDataType<T>::kOne(),
        CudnnDataType<T>::kZero(),
        data_desc_,
        x_tmp.template data<T>(),
        data_desc_,
        d_y_tmp.template data<T>(),
        data_desc_,
        d_x->template data<T>(),
        in_param_desc_,
        scale_tmp.template data<BatchNormParamType<T>>(),
        d_scale_tmp.template data<BatchNormParamType<T>>(),
        d_bias_tmp.template data<BatchNormParamType<T>>(),
        epsilon,
        saved_mean_data,
        saved_var_data));
#endif
  } else {
    if (d_x) {
      GradComputeDX<T, block><<<NxC, block, 0, dev_ctx.stream()>>>(
          d_y.data<T>(),
          scale_tmp.data<BatchNormParamType<T>>(),
          saved_mean_data,
          x.data<T>(),
          saved_var_data,
          C,
          H * W * D,
          d_x->data<T>());
    }
  }
  if (d_scale && d_bias) {
    add_param<AccT, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
        d_scale_tmp.data<AccT>(), d_scale->data<AccT>(), N, C);
    add_param<AccT, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
        d_bias_tmp.data<AccT>(), d_bias->data<AccT>(), N, C);
  }

#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::miopenDestroyTensorDescriptor(data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::miopenDestroyTensorDescriptor(in_param_desc_));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::cudnnDestroyTensorDescriptor(data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::cudnnDestroyTensorDescriptor(in_param_desc_));
#endif
}

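// InstanceNormDoubleGrad: given the forward saved mean / inverse std and the
// incoming second-order gradients (ddx, ddscale, ddbias), launches one CUDA
// kernel per requested output: DoubleGradComputeDX for dx,
// DoubleGradComputeDScale (followed by an add_param reduction over N) for
// dscale, and DoubleGradComputeDDY for ddy.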
template <typename T, typename Context>
void InstanceNormDoubleGradKernel(const Context &dev_ctx,
                                  const DenseTensor &x,
                                  const paddle::optional<DenseTensor> &scale,
                                  const DenseTensor &saved_mean,
                                  const DenseTensor &saved_variance,
                                  const DenseTensor &dy,
                                  const paddle::optional<DenseTensor> &ddx,
                                  const paddle::optional<DenseTensor> &ddscale,
                                  const paddle::optional<DenseTensor> &ddbias,
                                  float epsilon_f,
                                  DenseTensor *dx,
                                  DenseTensor *dscale,
                                  DenseTensor *ddy) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
  const auto *Scale = scale.get_ptr();
  const auto *ddX = ddx.get_ptr();
  const auto *ddScale = ddscale.get_ptr();
  const auto *ddBias = ddbias.get_ptr();
  const double epsilon = static_cast<double>(epsilon_f);
  const T *x_data = x.data<T>();
  const T *dy_data = dy.data<T>();
  const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());
  const AccT *ddscale_data =
      (ddScale == nullptr ? nullptr : ddScale->data<AccT>());
  const AccT *ddbias_data =
      (ddBias == nullptr ? nullptr : ddBias->data<AccT>());
  const AccT *mean_data = saved_mean.data<AccT>();
  const AccT *variance_data = saved_variance.data<AccT>();
  phi::funcs::SetConstant<GPUContext, T> set_zero;
  phi::funcs::SetConstant<GPUContext, AccT> set_zero_AccT;

  auto &x_dims = x.dims();
  int N, C, H, W, D;
  funcs::ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
  int NxC = N * C;
  const int n = x.numel();
  int sample_size = n / N / C;

  DenseTensor scale_tmp;
  if (!Scale) {
    scale_tmp.Resize({C});
    dev_ctx.template Alloc<AccT>(&scale_tmp);
    set_zero_AccT(dev_ctx, &scale_tmp, static_cast<AccT>(1));
  }
  const AccT *scale_data = Scale ? Scale->data<AccT>() : scale_tmp.data<AccT>();
  const int block = 512;
  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / block, 1);
  const int grid = NxC;
  const int grid1 = (C + block - 1) / block;

  if (dx) {
    T *dx_data = dev_ctx.template Alloc<T>(dx);
    set_zero(dev_ctx, dx, static_cast<T>(0));
    DoubleGradComputeDX<T, AccT, block>
        <<<grid, block, 0, dev_ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
                                               ddx_data,
                                               dy_data,
                                               scale_data,
                                               ddscale_data,
                                               C,
                                               sample_size,
                                               epsilon,
                                               dx_data);
  }
  if (dscale) {
    DenseTensor dscale_tmp;
    dscale_tmp.Resize({NxC});
    dev_ctx.template Alloc<AccT>(&dscale_tmp);
    set_zero_AccT(dev_ctx, &dscale_tmp, static_cast<AccT>(0));
    AccT *dscale_tmp_data = dscale_tmp.data<AccT>();

    AccT *dscale_data = dev_ctx.template Alloc<AccT>(dscale);
    set_zero_AccT(dev_ctx, dscale, static_cast<AccT>(0));
    DoubleGradComputeDScale<T, AccT, block>
        <<<grid, block, 0, dev_ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
                                               ddx_data,
                                               dy_data,
                                               C,
                                               sample_size,
                                               epsilon,
                                               dscale_tmp_data);
    add_param<AccT, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
        dscale_tmp.data<AccT>(), dscale->data<AccT>(), N, C);
  }
  if (ddy) {
    T *ddy_data = dev_ctx.template Alloc<T>(ddy);
    set_zero(dev_ctx, ddy, static_cast<T>(0));
    DoubleGradComputeDDY<T, AccT, block>
        <<<grid, block, 0, dev_ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
                                               ddscale_data,
                                               ddbias_data,
                                               ddx_data,
                                               scale_data,
                                               C,
                                               sample_size,
                                               epsilon,
                                               ddy_data);
  }
}
}  // namespace phi

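// Kernel registration: the MIOpen build omits double, bfloat16 support
// requires cuDNN >= 8.1, and older cuDNN builds register float, double, and
// float16 only.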
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
PD_REGISTER_KERNEL(instance_norm_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormGradKernel,
                   float,
                   phi::dtype::float16) {}
PD_REGISTER_KERNEL(instance_norm_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormDoubleGradKernel,
                   float,
                   phi::dtype::float16) {}
#elif CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(instance_norm_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(instance_norm_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormDoubleGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(instance_norm_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
PD_REGISTER_KERNEL(instance_norm_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormDoubleGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
#endif