norm_utils.cu.h 26.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
20
#ifdef __NVCC__
21
#include "cub/cub.cuh"
22 23 24 25 26
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
27
#include "paddle/phi/common/layout.h"
28
#include "paddle/phi/kernels/funcs/math_function.h"
29

30 31 32 33 34 35
#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#else
#define LAUNCH_BOUNDS(BlockDim)
#endif

36 37
namespace phi {
namespace funcs {
38

39
using DataLayout = phi::DataLayout;
40 41 42 43 44 45 46 47 48 49 50 51 52

// Gradient of the batch-norm backward pass with respect to x (double grad).
//
// math: dx = scale * ((x - mean) * inv_var.pow(3) / NxHxW *
//                         (np.mean(ddx, axis=(n,h,w)) *
//                              np.sum(dy, axis=(n,h,w)) -
//                          np.sum(dy * ddx, axis=(n,h,w)) +
//                          3 * np.mean(dy * (x - mean), axis=(n,h,w)) *
//                              inv_var.pow(2) *
//                              np.sum(ddx * (x - mean), axis=(n,h,w))) +
//                     inv_var.pow(3) / NxHxW *
//                         np.sum(ddx * (x - mean), axis=(n,h,w)) *
//                         (np.mean(dy, axis=(n,h,w)) - dy) +
//                     inv_var.pow(3) / NxHxW *
//                         np.sum(dy * (x - mean), axis=(n,h,w)) *
//                         (np.mean(ddx, axis=(n,h,w)) - ddx)) +
//      ddscale * (dy * inv_var - inv_var * np.mean(dy, axis=(n,h,w)) -
//                 inv_var.pow(3) * (x - mean) *
//                     np.mean(dy * (x - mean), axis=(n,h,w)))
//
// Work split: block b handles channels b, b + gridDim.x, ...; inside a
// channel the threads stride over its N * sample_size elements and the
// per-channel sums are combined with cub::BlockReduce<T, BlockDim>, so the
// kernel is meant to be launched with blockDim.x == BlockDim.
//
// `variance` is multiplied in as-is, i.e. it is expected to already hold
// inv_var; `epsilon` is not used by this kernel. Results are accumulated into
// dx with +=, so the caller must initialise dx. ddx and ddscale are optional:
// passing nullptr skips the corresponding term.
template <typename T, int BlockDim, phi::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDX(
    const T *x,
    const T *mean,
    const T *variance,
    const T *ddx,
    const T *dy,
    const T *scale,
    const T *ddscale,
    const int N,
    const int C,
    const int sample_size,
    const double epsilon,
    T *dx) {
  const int outer_size = C;
  const int inner_size = N * sample_size;
  // Same value for every thread of the launch, so branching on it never
  // makes the block diverge around a barrier or a cub collective.
  const bool has_ddx = (ddx != nullptr);

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T dy_sum_val;
  __shared__ T ddx_sum_val;
  __shared__ T dy_mul_ddx_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    T mean_val = mean[i];
    T var_val = variance[i];
    T dy_sum = 0;
    T ddx_sum = 0;
    T dy_mul_ddx_sum = 0;
    T dy_mul_x_sub_mean_sum = 0;
    T ddx_mul_x_sub_mean_sum = 0;
    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index =
          layout == phi::DataLayout::kNCHW
              ? (j / sample_size * C + i) * sample_size + j % sample_size
              : j * outer_size + i;
      // ddx is optional: a missing ddx contributes zeros instead of being
      // dereferenced. The dy-only sums are still needed by the ddscale term.
      T ddx_i = has_ddx ? ddx[index] : static_cast<T>(0);
      T dy_i = dy[index];
      T tmp = x[index] - mean_val;

      dy_sum += dy_i;
      ddx_sum += ddx_i;
      dy_mul_ddx_sum += (ddx_i * dy_i);

      dy_mul_x_sub_mean_sum += (dy_i * tmp);
      ddx_mul_x_sub_mean_sum += (ddx_i * tmp);
    }

    dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
    ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
    dy_mul_ddx_sum =
        BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum());
    dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                                .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
    ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                                 .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

    // The reduced values are only valid in thread 0; broadcast them.
    if (threadIdx.x == 0) {
      dy_sum_val = dy_sum;
      ddx_sum_val = ddx_sum;
      dy_mul_ddx_sum_val = dy_mul_ddx_sum;
      dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
      ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
    }
    __syncthreads();

    if (has_ddx) {
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index =
            layout == phi::DataLayout::kNCHW
                ? (j / sample_size * C + i) * sample_size + j % sample_size
                : j * outer_size + i;
        dx[index] +=
            ((x[index] - mean_val) * var_val * var_val * var_val / inner_size *
                 (ddx_sum_val * dy_sum_val / inner_size - dy_mul_ddx_sum_val +
                  3. * dy_mul_x_sub_mean_sum_val * var_val *
                      ddx_mul_x_sub_mean_sum_val * var_val / inner_size) +
             ddx_mul_x_sub_mean_sum_val * var_val / inner_size * var_val *
                 var_val * (dy_sum_val / inner_size - dy[index]) +
             dy_mul_x_sub_mean_sum_val * var_val / inner_size * var_val *
                 var_val * (ddx_sum_val / inner_size - ddx[index])) *
            scale[i];
      }
    }
    __syncthreads();
    if (ddscale != nullptr) {
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index =
            layout == phi::DataLayout::kNCHW
                ? (j / sample_size * C + i) * sample_size + j % sample_size
                : j * outer_size + i;
        dx[index] += (dy[index] * var_val - dy_sum_val / inner_size * var_val -
                      (x[index] - mean_val) * var_val * var_val *
                          dy_mul_x_sub_mean_sum_val * var_val / inner_size) *
                     ddscale[i];
      }
    }
    // Make sure every thread is done reading the shared *_sum_val values
    // before thread 0 overwrites them for the next channel, and before the
    // cub temp storage is reused.
    __syncthreads();
  }
}

// Gradient of the batch-norm backward pass with respect to dy (double grad).
//
// math: ddy = (x - mean) * inv_var * ddscale + ddbias +
//             scale * inv_var * (ddx - np.mean(ddx, axis=(n,h,w)) -
//                                (x - mean) * inv_var.pow(2) *
//                                    np.mean(ddx * (x - mean), axis=(n,h,w)))
//
// Work split: block b handles channels b, b + gridDim.x, ...; inside a
// channel the threads stride over its N * sample_size elements and the
// per-channel sums are combined with cub::BlockReduce<T, BlockDim>, so the
// kernel is meant to be launched with blockDim.x == BlockDim.
//
// `variance` is multiplied in as-is, i.e. it is expected to already hold
// inv_var; `epsilon` is not used. Results are accumulated into ddy with +=,
// so the caller must initialise ddy. ddx, ddscale and ddbias are optional:
// passing nullptr skips the corresponding term.
template <typename T, int BlockDim, phi::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDDY(
    const T *x,
    const T *mean,
    const T *variance,
    const T *ddscale,
    const T *ddbias,
    const T *ddx,
    const T *scale,
    const int N,
    const int C,
    const int sample_size,
    const double epsilon,
    T *ddy) {
  const int outer_size = C;
  const int inner_size = N * sample_size;

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T ddx_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    T mean_val = mean[i];
    T var_val = variance[i];

    // ddx is optional, so its reductions are only done when it exists. The
    // test is uniform across the block, so the cub collectives and barriers
    // inside are reached by all threads or by none.
    if (ddx != nullptr) {
      T ddx_sum = 0;
      T ddx_mul_x_sub_mean_sum = 0;
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index =
            layout == phi::DataLayout::kNCHW
                ? (j / sample_size * C + i) * sample_size + j % sample_size
                : j * outer_size + i;
        T ddx_i = ddx[index];
        ddx_sum += ddx_i;
        ddx_mul_x_sub_mean_sum += (ddx_i * (x[index] - mean_val));
      }
      ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
      ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                                   .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

      // The reduced values are only valid in thread 0; broadcast them.
      if (threadIdx.x == 0) {
        ddx_sum_val = ddx_sum;
        ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
      }
      __syncthreads();

      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index =
            layout == phi::DataLayout::kNCHW
                ? (j / sample_size * C + i) * sample_size + j % sample_size
                : j * outer_size + i;
        ddy[index] += scale[i] * var_val *
                      (ddx[index] - ddx_sum_val / inner_size -
                       (x[index] - mean_val) * var_val *
                           ddx_mul_x_sub_mean_sum_val * var_val / inner_size);
      }
    }
    __syncthreads();
    if (ddscale != nullptr) {
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index =
            layout == phi::DataLayout::kNCHW
                ? (j / sample_size * C + i) * sample_size + j % sample_size
                : j * outer_size + i;
        ddy[index] += (x[index] - mean_val) * var_val * ddscale[i];
      }
    }
    __syncthreads();
    if (ddbias != nullptr) {
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index =
            layout == phi::DataLayout::kNCHW
                ? (j / sample_size * C + i) * sample_size + j % sample_size
                : j * outer_size + i;
        ddy[index] += ddbias[i];
      }
    }
    // Do not let thread 0 overwrite the shared sums (or reuse the cub temp
    // storage) for the next channel while other threads are still here.
    __syncthreads();
  }
}

// Gradient of the batch-norm backward pass with respect to scale (double
// grad, batch statistics).
//
// math: dscale = np.sum(ddx * inv_var *
//                       (dy - np.mean(dy, axis=(n,h,w)) -
//                        (x - mean) * inv_var.pow(2) *
//                            np.mean(dy * (x - mean), axis=(n,h,w))),
//                       axis=(n,h,w))
//
// Work split: block b handles channels b, b + gridDim.x, ...; inside a
// channel the threads stride over its N * sample_size elements and combine
// partial sums with cub::BlockReduce<T, BlockDim>.
//
// `variance` is multiplied in as-is (expected to already hold inv_var) and
// `epsilon` is unused. Thread 0 adds the per-channel result to dscale[c], so
// the caller must initialise dscale. When ddx is nullptr nothing is written.
template <typename T, int BlockDim, phi::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScale(
    const T *x,
    const T *mean,
    const T *variance,
    const T *ddx,
    const T *dy,
    const int N,
    const int C,
    const int sample_size,
    const double epsilon,
    T *dscale) {
  using BlockReduce = cub::BlockReduce<T, BlockDim>;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage dscale_tmp_storage;
  __shared__ T dy_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;

  const int channels = C;
  const int per_channel = N * sample_size;

  // Flat offset of the k-th element of channel ch for the compile-time layout.
  auto offset_of = [&](int ch, int k) -> int {
    return layout == phi::DataLayout::kNCHW
               ? (k / sample_size * C + ch) * sample_size + k % sample_size
               : k * channels + ch;
  };

  for (int c = blockIdx.x; c < channels; c += gridDim.x) {
    const T mu = mean[c];
    const T inv_std = variance[c];

    // Pass 1: per-channel sums of dy and dy * (x - mean).
    T part_dy = 0;
    T part_dy_xc = 0;
    for (int k = threadIdx.x; k < per_channel; k += blockDim.x) {
      const int off = offset_of(c, k);
      const T g = dy[off];
      part_dy += g;
      part_dy_xc += (g * (x[off] - mu));
    }
    part_dy = BlockReduce(dy_storage).Reduce(part_dy, cub::Sum());
    part_dy_xc =
        BlockReduce(dy_mul_x_sub_mean_storage).Reduce(part_dy_xc, cub::Sum());

    // Only thread 0 holds the reduced values; publish them to the block.
    if (threadIdx.x == 0) {
      dy_sum_val = part_dy;
      dy_mul_x_sub_mean_sum_val = part_dy_xc;
    }
    __syncthreads();

    // ddx is uniform across the block, so skipping here is divergence-free.
    if (ddx == nullptr) {
      continue;
    }

    // Pass 2: sum of ddx * (normalized-gradient) over the channel.
    T acc = 0;
    for (int k = threadIdx.x; k < per_channel; k += blockDim.x) {
      const int off = offset_of(c, k);
      acc += ddx[off] * inv_std *
             (dy[off] - dy_sum_val / per_channel -
              dy_mul_x_sub_mean_sum_val * (x[off] - mu) * inv_std * inv_std /
                  per_channel);
    }
    acc = BlockReduce(dscale_tmp_storage).Reduce(acc, cub::Sum());

    if (threadIdx.x == 0) {
      dscale[c] += acc;
    }
    __syncthreads();
  }
}

// Gradient with respect to scale when running (global) statistics are used.
//
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var,
//       with inv_var = 1 / sqrt(variance + epsilon)
//
// Work split: block b handles channels b, b + gridDim.x, ...; inside a
// channel the threads stride over its N * sample_size elements and combine
// partial products with cub::BlockReduce<T, BlockDim>. Thread 0 overwrites
// dscale[c]. When ddx is nullptr nothing is written, so dscale keeps whatever
// the caller initialised it to.
template <typename T, int BlockDim, phi::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScaleWithGlobal(
    const T *ddx,
    const T *variance,
    const T *dy,
    const double epsilon,
    const int N,
    const int C,
    const int sample_size,
    T *dscale) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ddx_mul_dy_storage;

  // ddx is optional. Bail out before the loop below would dereference a null
  // pointer; the condition is uniform, so no thread is left at a barrier.
  if (ddx == nullptr) {
    return;
  }

  int outer_size = C;
  int inner_size = N * sample_size;
  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    T ddx_mul_dy_sum = 0;
    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index =
          layout == phi::DataLayout::kNCHW
              ? (j / sample_size * C + i) * sample_size + j % sample_size
              : j * outer_size + i;
      T ddx_i = ddx[index];
      T dy_i = dy[index];
      ddx_mul_dy_sum += (ddx_i * dy_i);
    }
    ddx_mul_dy_sum =
        BlockReduce(ddx_mul_dy_storage).Reduce(ddx_mul_dy_sum, cub::Sum());
    // The block-wide sum is only valid in thread 0, which is also the only
    // thread that needs it: write the result directly (no shared broadcast,
    // no redundant stores from every thread).
    if (threadIdx.x == 0) {
      T inv_var_i = 1.0 / sqrt(variance[i] + epsilon);
      dscale[i] = inv_var_i * ddx_mul_dy_sum;
    }
    // cub requires a barrier before ddx_mul_dy_storage is reused by the next
    // channel's reduction.
    __syncthreads();
  }
}

// Gradient with respect to x when running (global) statistics are used.
//
// math: dx = ddscale * dy * inv_var, with inv_var = 1 / sqrt(variance + epsilon)
//
// Element-wise: threads walk the `num` elements with a grid-stride loop and
// recover each element's channel from its flat offset according to `layout`.
// dx is overwritten (not accumulated). When ddscale is nullptr nothing is
// written.
template <typename T, phi::DataLayout layout>
__global__ void DoubleGradComputeDXWithGlobal(const T *dy,
                                              const T *ddscale,
                                              const T *variance,
                                              const double epsilon,
                                              const int C,
                                              const int sample_size,
                                              const int num,
                                              T *dx) {
  if (ddscale == nullptr) {
    return;
  }
  const int step = blockDim.x * gridDim.x;
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < num;
       pos += step) {
    const int ch = (layout == phi::DataLayout::kNCHW)
                       ? (pos / sample_size) % C
                       : pos % C;
    T inv_std = 1.0 / sqrt(variance[ch] + epsilon);
    dx[pos] = dy[pos] * ddscale[ch] * inv_std;
  }
}

// Gradient with respect to dy when running (global) statistics are used.
//
// math: ddy = scale * ddx * inv_var + ddbias +
//             ddscale * (x - mean) * inv_var,
//       with inv_var = 1 / sqrt(variance + epsilon)
//
// Element-wise: threads walk the `num` elements with a grid-stride loop and
// recover each element's channel from its flat offset according to `layout`.
// Every element is touched by exactly one thread, so no block barrier is
// needed; each element is read once, updated in the same order as the terms
// above (ddx, then ddscale, then ddbias), and written back once. Results are
// accumulated into ddy, so the caller must initialise it. ddx, ddscale and
// ddbias are optional (nullptr skips the term).
template <typename T, phi::DataLayout layout>
__global__ void DoubleGradComputeDDYWithGlobal(const T *ddx,
                                               const T *scale,
                                               const T *mean,
                                               const T *variance,
                                               const T *x,
                                               const T *ddbias,
                                               const T *ddscale,
                                               const double epsilon,
                                               const int C,
                                               const int sample_size,
                                               const int num,
                                               T *ddy) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;

  for (int i = gid; i < num; i += stride) {
    const int c =
        layout == phi::DataLayout::kNCHW ? i / sample_size % C : i % C;
    // Computed once per element instead of once per term.
    T inv_var = 1.0 / sqrt(variance[c] + epsilon);
    T acc = ddy[i];
    if (ddx != nullptr) {
      acc += ddx[i] * scale[c] * inv_var;
    }
    if (ddscale != nullptr) {
      acc += (x[i] - mean[c]) * inv_var * ddscale[c];
    }
    if (ddbias != nullptr) {
      acc += ddbias[c];
    }
    ddy[i] = acc;
  }
}

// Host-side driver for the batch-norm double-grad computation on GPU.
//
// For each requested output (non-null dX / dScale / ddY) it allocates the
// tensor, zero-fills it with SetConstant and launches the matching kernel,
// instantiated for kNHWC when data_layout == kNHWC and for kNCHW otherwise.
// All launches are issued on ctx.stream(); nothing here synchronizes.
//
// Statistics: with use_global_stats the running Mean / Variance are passed to
// the *WithGlobal kernels, which compute inv_var = 1 / sqrt(variance +
// epsilon) themselves. Otherwise Saved_mean / Saved_variance are passed to
// kernels that multiply by the variance value directly, so Saved_variance
// presumably holds the saved inverse standard deviation -- NOTE(review):
// confirm against the forward kernel.
//
// A null Scale is replaced by a tensor of ones. ddX, ddScale and ddBias are
// optional; nullptr is forwarded to the kernels for a missing one.
//
// For an empty input (numel == 0) the requested outputs are only allocated
// and zero-filled; no kernel is launched.
template <typename DeviceContext, typename T>
void NormDoubleGradFunctor(const DeviceContext &ctx,
                           const DataLayout data_layout,
                           const phi::DenseTensor *X,
                           const phi::DenseTensor *Scale,
                           const phi::DenseTensor *dY,
                           const phi::DenseTensor *Saved_mean,
                           const phi::DenseTensor *Saved_variance,
                           const phi::DenseTensor *Mean,
                           const phi::DenseTensor *Variance,
                           const double epsilon,
                           const bool use_global_stats,
                           const phi::DenseTensor *ddX,
                           const phi::DenseTensor *ddScale,
                           const phi::DenseTensor *ddBias,
                           phi::DenseTensor *dX,
                           phi::DenseTensor *dScale,
                           phi::DenseTensor *ddY) {
  const T *x_data = X->data<T>();
  const T *dy_data = dY->data<T>();
  const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());

  const T *ddscale_data = (ddScale == nullptr ? nullptr : ddScale->data<T>());
  const T *ddbias_data = (ddBias == nullptr ? nullptr : ddBias->data<T>());

  phi::funcs::SetConstant<DeviceContext, T> set_constant;

  auto &x_dims = X->dims();
  const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
                                                  : x_dims[x_dims.size() - 1]);
  const int N = x_dims[0];
  const int num = X->numel();

  if (num == 0) {
    // Empty input: every requested gradient is identically zero. Returning
    // here avoids `num / N / C` below (N or C may be 0) and launching kernels
    // with a zero-sized grid, which is an invalid launch configuration.
    if (dX) {
      ctx.template Alloc<T>(dX);
      set_constant(ctx, dX, static_cast<T>(0));
    }
    if (dScale) {
      ctx.template Alloc<T>(dScale);
      set_constant(ctx, dScale, static_cast<T>(0));
    }
    if (ddY) {
      ctx.template Alloc<T>(ddY);
      set_constant(ctx, ddY, static_cast<T>(0));
    }
    return;
  }
  const int sample_size = num / N / C;

  // Kernels always read a scale vector; substitute ones when it is absent.
  phi::DenseTensor scale_tmp;
  if (!Scale) {
    scale_tmp.Resize({C});
    ctx.template Alloc<T>(&scale_tmp);
    set_constant(ctx, &scale_tmp, static_cast<T>(1));
  }
  const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
#ifdef __HIPCC__
  const int block = 256;
#else
  const int block = 512;
#endif
  // `grid` drives the per-channel reduction kernels (at most one block per
  // channel); `grid1` covers all elements for the element-wise kernels.
  int max_threads = ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / block, 1);
  int grid = std::min(C, max_blocks);
  int grid1 = (num + block - 1) / block;

  const T *mean_data, *variance_data;
  if (use_global_stats) {
    const auto *running_mean = Mean;
    const auto *running_var = Variance;
    const auto *running_mean_data = running_mean->template data<T>();
    const auto *running_var_data = running_var->template data<T>();
    mean_data = running_mean_data;
    variance_data = running_var_data;
  } else {
    const T *smean_data = Saved_mean->data<T>();
    const T *svariance_data = Saved_variance->data<T>();

    mean_data = smean_data;
    variance_data = svariance_data;
  }

  if (dX) {
    T *dx_data = ctx.template Alloc<T>(dX);
    set_constant(ctx, dX, static_cast<T>(0));
    if (use_global_stats) {
      if (data_layout == DataLayout::kNHWC) {
        DoubleGradComputeDXWithGlobal<T, DataLayout::kNHWC>
            <<<grid1, block, 0, ctx.stream()>>>(dy_data,
                                                ddscale_data,
                                                variance_data,
                                                epsilon,
                                                C,
                                                sample_size,
                                                num,
                                                dx_data);
      } else {
        DoubleGradComputeDXWithGlobal<T, DataLayout::kNCHW>
            <<<grid1, block, 0, ctx.stream()>>>(dy_data,
                                                ddscale_data,
                                                variance_data,
                                                epsilon,
                                                C,
                                                sample_size,
                                                num,
                                                dx_data);
      }
    } else {
      if (data_layout == DataLayout::kNHWC) {
        DoubleGradComputeDX<T, block, DataLayout::kNHWC>
            <<<grid, block, 0, ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
                                               ddx_data,
                                               dy_data,
                                               scale_data,
                                               ddscale_data,
                                               N,
                                               C,
                                               sample_size,
                                               epsilon,
                                               dx_data);
      } else {
        DoubleGradComputeDX<T, block, DataLayout::kNCHW>
            <<<grid, block, 0, ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
                                               ddx_data,
                                               dy_data,
                                               scale_data,
                                               ddscale_data,
                                               N,
                                               C,
                                               sample_size,
                                               epsilon,
                                               dx_data);
      }
    }
  }
  if (dScale) {
    T *dscale_data = ctx.template Alloc<T>(dScale);
    set_constant(ctx, dScale, static_cast<T>(0));
    if (use_global_stats) {
      if (data_layout == DataLayout::kNHWC) {
        DoubleGradComputeDScaleWithGlobal<T, block, DataLayout::kNHWC>
            <<<grid, block, 0, ctx.stream()>>>(ddx_data,
                                               variance_data,
                                               dy_data,
                                               epsilon,
                                               N,
                                               C,
                                               sample_size,
                                               dscale_data);
      } else {
        DoubleGradComputeDScaleWithGlobal<T, block, DataLayout::kNCHW>
            <<<grid, block, 0, ctx.stream()>>>(ddx_data,
                                               variance_data,
                                               dy_data,
                                               epsilon,
                                               N,
                                               C,
                                               sample_size,
                                               dscale_data);
      }
    } else {
      if (data_layout == DataLayout::kNHWC) {
        DoubleGradComputeDScale<T, block, DataLayout::kNHWC>
            <<<grid, block, 0, ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
                                               ddx_data,
                                               dy_data,
                                               N,
                                               C,
                                               sample_size,
                                               epsilon,
                                               dscale_data);
      } else {
        DoubleGradComputeDScale<T, block, DataLayout::kNCHW>
            <<<grid, block, 0, ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
                                               ddx_data,
                                               dy_data,
                                               N,
                                               C,
                                               sample_size,
                                               epsilon,
                                               dscale_data);
      }
    }
  }
  if (ddY) {
    T *ddy_data = ctx.template Alloc<T>(ddY);
    set_constant(ctx, ddY, static_cast<T>(0));
    if (use_global_stats) {
      if (data_layout == DataLayout::kNHWC) {
        DoubleGradComputeDDYWithGlobal<T, DataLayout::kNHWC>
            <<<grid1, block, 0, ctx.stream()>>>(ddx_data,
                                                scale_data,
                                                mean_data,
                                                variance_data,
                                                x_data,
                                                ddbias_data,
                                                ddscale_data,
                                                epsilon,
                                                C,
                                                sample_size,
                                                num,
                                                ddy_data);
      } else {
        DoubleGradComputeDDYWithGlobal<T, DataLayout::kNCHW>
            <<<grid1, block, 0, ctx.stream()>>>(ddx_data,
                                                scale_data,
                                                mean_data,
                                                variance_data,
                                                x_data,
                                                ddbias_data,
                                                ddscale_data,
                                                epsilon,
                                                C,
                                                sample_size,
                                                num,
                                                ddy_data);
      }
    } else {
      if (data_layout == DataLayout::kNHWC) {
        DoubleGradComputeDDY<T, block, DataLayout::kNHWC>
            <<<grid, block, 0, ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
                                               ddscale_data,
                                               ddbias_data,
                                               ddx_data,
                                               scale_data,
                                               N,
                                               C,
                                               sample_size,
                                               epsilon,
                                               ddy_data);
      } else {
        DoubleGradComputeDDY<T, block, DataLayout::kNCHW>
            <<<grid, block, 0, ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
                                               ddscale_data,
                                               ddbias_data,
                                               ddx_data,
                                               scale_data,
                                               N,
                                               C,
                                               sample_size,
                                               epsilon,
                                               ddy_data);
      }
    }
  }
}
673 674
}  // namespace funcs
}  // namespace phi