norm_utils.cu.h 30.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
20
#ifdef __NVCC__
21
#include "cub/cub.cuh"
22 23 24 25 26
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
27
#include "paddle/phi/common/layout.h"
28
#include "paddle/phi/kernels/funcs/math_function.h"
29
#include "paddle/phi/kernels/funcs/reduce_function.h"
30

// LAUNCH_BOUNDS(BlockDim) annotates a kernel with its maximum block size.
// Under HIP it expands to __launch_bounds__(BlockDim); in every other build
// (including CUDA) it expands to nothing.
#if defined(__HIPCC__)
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#else
#define LAUNCH_BOUNDS(BlockDim)
#endif

namespace phi {
namespace funcs {

// Second-order batch-norm gradient w.r.t. x (batch statistics).
//
// Let M = N * sample_size (elements per channel), mu = mean[c] and
// inv_std = variance[c]. NOTE(review): despite its name, `variance` is used
// directly as a multiplicative inverse standard deviation (no sqrt, and
// `epsilon` is never read); presumably the caller passes the saved
// 1/sqrt(var + eps) — confirm against the forward kernel. With all sums over
// the (n, h, w) axes of channel c, this kernel computes:
//
//   dx += scale * ( (x - mu) * inv_std^3 / M *
//                     ( sum(ddx) * sum(dy) / M - sum(dy * ddx) +
//                       3 * inv_std^2 * sum(dy * (x - mu)) *
//                           sum(ddx * (x - mu)) / M )
//                 + inv_std^3 / M * sum(ddx * (x - mu)) * (sum(dy) / M - dy)
//                 + inv_std^3 / M * sum(dy * (x - mu)) * (sum(ddx) / M - ddx) )
//      + ddscale * ( inv_std * dy - inv_std * sum(dy) / M
//                    - inv_std^3 * (x - mu) * sum(dy * (x - mu)) / M )
//
// The ddx term is skipped when ddx == nullptr, and the ddscale term is
// skipped when ddscale == nullptr. dx is accumulated into (+=), so the caller
// must initialize it.
//
// Launch: 1-D grid and block. Block b handles channels b, b + gridDim.x, ...
// BlockDim must equal blockDim.x because it sizes cub::BlockReduce.
template <typename T, int BlockDim, phi::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDX(
    const T *x,
    const T *mean,
    const T *variance,
    const T *ddx,
    const T *dy,
    const T *scale,
    const T *ddscale,
    const int N,
    const int C,
    const int sample_size,
    const double epsilon,
    T *dx) {
  const int outer_size = C;
  const int inner_size = N * sample_size;
  // ddx is optional and must not be dereferenced when absent. The flag comes
  // from a kernel argument, so it is uniform across the block, and every
  // thread still reaches the block-wide reductions and barriers below.
  const bool has_ddx = ddx != nullptr;
  // Flat offset of the j-th element (within the reduction axes) of channel c.
  const auto offset_of = [=](int c, int j) {
    return layout == phi::DataLayout::kNCHW
               ? (j / sample_size * C + c) * sample_size + j % sample_size
               : j * C + c;
  };

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T dy_sum_val;
  __shared__ T ddx_sum_val;
  __shared__ T dy_mul_ddx_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    T mean_val = mean[i];
    T var_val = variance[i];
    T dy_sum = 0;
    T ddx_sum = 0;
    T dy_mul_ddx_sum = 0;
    T dy_mul_x_sub_mean_sum = 0;
    T ddx_mul_x_sub_mean_sum = 0;
    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = offset_of(i, j);
      // An absent ddx contributes zero to every ddx-based partial sum.
      T ddx_i = 0;
      if (has_ddx) {
        ddx_i = ddx[index];
      }
      T dy_i = dy[index];
      T tmp = x[index] - mean_val;

      dy_sum += dy_i;
      ddx_sum += ddx_i;
      dy_mul_ddx_sum += (ddx_i * dy_i);

      dy_mul_x_sub_mean_sum += (dy_i * tmp);
      ddx_mul_x_sub_mean_sum += (ddx_i * tmp);
    }

    dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
    ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
    dy_mul_ddx_sum =
        BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum());
    dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                                .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
    ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                                 .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

    // Reduce() results are only valid on thread 0; broadcast via shared.
    if (threadIdx.x == 0) {
      dy_sum_val = dy_sum;
      ddx_sum_val = ddx_sum;
      dy_mul_ddx_sum_val = dy_mul_ddx_sum;
      dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
      ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
    }
    __syncthreads();

    if (has_ddx) {
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index = offset_of(i, j);
        dx[index] +=
            ((x[index] - mean_val) * var_val * var_val * var_val / inner_size *
                 (ddx_sum_val * dy_sum_val / inner_size - dy_mul_ddx_sum_val +
                  3. * dy_mul_x_sub_mean_sum_val * var_val *
                      ddx_mul_x_sub_mean_sum_val * var_val / inner_size) +
             ddx_mul_x_sub_mean_sum_val * var_val / inner_size * var_val *
                 var_val * (dy_sum_val / inner_size - dy[index]) +
             dy_mul_x_sub_mean_sum_val * var_val / inner_size * var_val *
                 var_val * (ddx_sum_val / inner_size - ddx[index])) *
            scale[i];
      }
    }
    __syncthreads();
    if (ddscale != nullptr) {
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index = offset_of(i, j);
        dx[index] += (dy[index] * var_val - dy_sum_val / inner_size * var_val -
                      (x[index] - mean_val) * var_val * var_val *
                          dy_mul_x_sub_mean_sum_val * var_val / inner_size) *
                     ddscale[i];
      }
    }
    // The ddscale loop above reads the shared sums, and thread 0 overwrites
    // them right after the next channel's reductions. This explicit barrier
    // avoids relying on cub::BlockReduce's internal synchronization.
    __syncthreads();
  }
}

// Second-order batch-norm gradient w.r.t. y (batch statistics).
//
// Let M = N * sample_size, mu = mean[c] and inv_std = variance[c].
// NOTE(review): `variance` is used directly as an inverse standard deviation
// and `epsilon` is never read; presumably it is the saved 1/sqrt(var + eps) —
// confirm with the caller. With sums over the (n, h, w) axes of channel c:
//
//   ddy += scale * inv_std * (ddx - sum(ddx) / M
//                             - (x - mu) * inv_std^2 * sum(ddx * (x - mu)) / M)
//        + (x - mu) * inv_std * ddscale
//        + ddbias
//
// Each term is skipped when its input (ddx, ddscale, ddbias) is nullptr.
// ddy is accumulated into (+=), so the caller must initialize it.
//
// Launch: 1-D grid and block. Block b handles channels b, b + gridDim.x, ...
// BlockDim must equal blockDim.x because it sizes cub::BlockReduce.
template <typename T, int BlockDim, phi::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDDY(
    const T *x,
    const T *mean,
    const T *variance,
    const T *ddscale,
    const T *ddbias,
    const T *ddx,
    const T *scale,
    const int N,
    const int C,
    const int sample_size,
    const double epsilon,
    T *ddy) {
  const int outer_size = C;
  const int inner_size = N * sample_size;
  // Flat offset of the j-th element (within the reduction axes) of channel c.
  const auto offset_of = [=](int c, int j) {
    return layout == phi::DataLayout::kNCHW
               ? (j / sample_size * C + c) * sample_size + j % sample_size
               : j * C + c;
  };

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T ddx_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    T mean_val = mean[i];
    T var_val = variance[i];

    // ddx is optional, so it is only dereferenced inside this branch. The
    // condition is a kernel argument and therefore uniform across the block,
    // so the reductions and the barrier inside are reached by every thread.
    if (ddx != nullptr) {
      T ddx_sum = 0;
      T ddx_mul_x_sub_mean_sum = 0;
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index = offset_of(i, j);
        T ddx_i = ddx[index];
        ddx_sum += ddx_i;
        ddx_mul_x_sub_mean_sum += (ddx_i * (x[index] - mean_val));
      }
      ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
      ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                                   .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

      // Reduce() results are only valid on thread 0; broadcast via shared.
      if (threadIdx.x == 0) {
        ddx_sum_val = ddx_sum;
        ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
      }
      __syncthreads();

      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index = offset_of(i, j);
        ddy[index] += scale[i] * var_val *
                      (ddx[index] - ddx_sum_val / inner_size -
                       (x[index] - mean_val) * var_val *
                           ddx_mul_x_sub_mean_sum_val * var_val / inner_size);
      }
    }
    // Keeps the shared sums and cub TempStorage safe from the next channel's
    // writes.
    __syncthreads();
    if (ddscale != nullptr) {
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index = offset_of(i, j);
        ddy[index] += (x[index] - mean_val) * var_val * ddscale[i];
      }
    }
    __syncthreads();
    if (ddbias != nullptr) {
      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index = offset_of(i, j);
        ddy[index] += ddbias[i];
      }
    }
  }
}

// Second-order batch-norm gradient w.r.t. scale (batch statistics).
//
// Let M = N * sample_size, mu = mean[c] and inv_std = variance[c].
// NOTE(review): `variance` is used directly as an inverse standard deviation
// and `epsilon` is never read — confirm with the caller. With sums over the
// (n, h, w) axes of channel c:
//
//   dscale[c] += sum( ddx * inv_std * (dy - sum(dy) / M
//                      - (x - mu) * inv_std^2 * sum(dy * (x - mu)) / M) )
//
// Nothing is accumulated when ddx is nullptr. dscale is accumulated into,
// so the caller must initialize it.
// Launch: block b handles channels b, b + gridDim.x, ...; BlockDim must equal
// blockDim.x because it sizes cub::BlockReduce.
template <typename T, int BlockDim, phi::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScale(
    const T *x,
    const T *mean,
    const T *variance,
    const T *ddx,
    const T *dy,
    const int N,
    const int C,
    const int sample_size,
    const double epsilon,
    T *dscale) {
  using BlockReduce = cub::BlockReduce<T, BlockDim>;
  __shared__ typename BlockReduce::TempStorage dy_reduce_tmp;
  __shared__ typename BlockReduce::TempStorage dy_xmu_reduce_tmp;
  __shared__ typename BlockReduce::TempStorage dscale_reduce_tmp;
  __shared__ T s_dy_total;
  __shared__ T s_dy_xmu_total;

  const int per_channel = N * sample_size;
  // Flat offset of the j-th element (within the reduction axes) of channel c.
  const auto offset_of = [=](int c, int j) {
    return layout == phi::DataLayout::kNCHW
               ? (j / sample_size * C + c) * sample_size + j % sample_size
               : j * C + c;
  };

  for (int c = blockIdx.x; c < C; c += gridDim.x) {
    const T mu = mean[c];
    const T inv_std = variance[c];

    // Phase 1: per-channel sum(dy) and sum(dy * (x - mu)).
    T dy_part = 0;
    T dy_xmu_part = 0;
    for (int j = threadIdx.x; j < per_channel; j += blockDim.x) {
      const int k = offset_of(c, j);
      const T dy_k = dy[k];
      dy_part += dy_k;
      dy_xmu_part += (dy_k * (x[k] - mu));
    }
    dy_part = BlockReduce(dy_reduce_tmp).Reduce(dy_part, cub::Sum());
    dy_xmu_part =
        BlockReduce(dy_xmu_reduce_tmp).Reduce(dy_xmu_part, cub::Sum());

    // Only thread 0 holds the reduced values; share them with the block.
    if (threadIdx.x == 0) {
      s_dy_total = dy_part;
      s_dy_xmu_total = dy_xmu_part;
    }
    __syncthreads();

    // ddx is a kernel argument, so every thread skips (or not) together.
    if (ddx == nullptr) {
      continue;
    }

    // Phase 2: contract ddx with the first-order dx expression.
    T local = 0;
    for (int j = threadIdx.x; j < per_channel; j += blockDim.x) {
      const int k = offset_of(c, j);
      local += ddx[k] * inv_std *
               (dy[k] - s_dy_total / per_channel -
                s_dy_xmu_total * (x[k] - mu) * inv_std * inv_std /
                    per_channel);
    }
    local = BlockReduce(dscale_reduce_tmp).Reduce(local, cub::Sum());

    if (threadIdx.x == 0) {
      dscale[c] += local;
    }
    __syncthreads();
  }
}

// Second-order batch-norm gradient w.r.t. scale with running (global)
// statistics. Per channel c, summing over the (n, h, w) axes:
//
//   dscale[c] = inv_std * sum(ddx * dy),  inv_std = 1 / sqrt(variance[c] + epsilon)
//
// dscale is overwritten, not accumulated. When ddx is nullptr the kernel
// returns immediately, which leaves dscale as the caller initialized it.
// Launch: block b handles channels b, b + gridDim.x, ...; BlockDim must equal
// blockDim.x because it sizes cub::BlockReduce.
template <typename T, int BlockDim, phi::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScaleWithGlobal(
    const T *ddx,
    const T *variance,
    const T *dy,
    const double epsilon,
    const int N,
    const int C,
    const int sample_size,
    T *dscale) {
  // ddx is optional and must not be dereferenced when absent. It is a kernel
  // argument, so the whole block returns together and no barrier is skipped
  // by only part of the block.
  if (ddx == nullptr) {
    return;
  }

  const int outer_size = C;
  const int inner_size = N * sample_size;
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ddx_mul_dy_storage;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    T ddx_mul_dy_sum = 0;
    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index =
          layout == phi::DataLayout::kNCHW
              ? (j / sample_size * C + i) * sample_size + j % sample_size
              : j * outer_size + i;
      ddx_mul_dy_sum += (ddx[index] * dy[index]);
    }
    ddx_mul_dy_sum =
        BlockReduce(ddx_mul_dy_storage).Reduce(ddx_mul_dy_sum, cub::Sum());

    // Reduce() result is only valid on thread 0, so it writes the output.
    if (threadIdx.x == 0) {
      T inv_var_i = 1.0 / sqrt(variance[i] + epsilon);
      dscale[i] = inv_var_i * ddx_mul_dy_sum;
    }
    // cub requires a barrier before TempStorage is reused for the next
    // channel.
    __syncthreads();
  }
}

// Second-order batch-norm gradient w.r.t. x with running (global) statistics:
//
//   dx[i] = dy[i] * ddscale[c] * inv_std,  inv_std = 1 / sqrt(variance[c] + epsilon)
//
// where c is the channel of element i for `layout`. When ddscale is nullptr,
// nothing is written and dx keeps the caller's values. Each thread visits
// elements first, first + blockDim.x * gridDim.x, ... below `num`.
template <typename T, phi::DataLayout layout>
__global__ void DoubleGradComputeDXWithGlobal(const T *dy,
                                              const T *ddscale,
                                              const T *variance,
                                              const double epsilon,
                                              const int C,
                                              const int sample_size,
                                              const int num,
                                              T *dx) {
  if (ddscale == nullptr) {
    return;
  }
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const int step = blockDim.x * gridDim.x;
  for (int i = first; i < num; i += step) {
    const int c =
        (layout == phi::DataLayout::kNCHW) ? (i / sample_size) % C : i % C;
    const T inv_std = 1.0 / sqrt(variance[c] + epsilon);
    dx[i] = dy[i] * ddscale[c] * inv_std;
  }
}

// Second-order batch-norm gradient w.r.t. y with running (global) statistics:
//
//   ddy[i] += ddx[i] * scale[c] * inv_std
//           + (x[i] - mean[c]) * inv_std * ddscale[c]
//           + ddbias[c],
//   inv_std = 1 / sqrt(variance[c] + epsilon),
//
// where c is the channel of element i for `layout`. Each term is skipped when
// its input (ddx, ddscale, ddbias) is nullptr, and the terms are added in the
// order shown. ddy is accumulated into, so the caller must initialize it.
// Each element is owned by exactly one thread (indices first,
// first + blockDim.x * gridDim.x, ... below `num`), so no barriers are needed.
template <typename T, phi::DataLayout layout>
__global__ void DoubleGradComputeDDYWithGlobal(const T *ddx,
                                               const T *scale,
                                               const T *mean,
                                               const T *variance,
                                               const T *x,
                                               const T *ddbias,
                                               const T *ddscale,
                                               const double epsilon,
                                               const int C,
                                               const int sample_size,
                                               const int num,
                                               T *ddy) {
  if (ddx == nullptr && ddscale == nullptr && ddbias == nullptr) {
    return;
  }
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const int step = blockDim.x * gridDim.x;
  for (int i = first; i < num; i += step) {
    const int c =
        (layout == phi::DataLayout::kNCHW) ? (i / sample_size) % C : i % C;
    T acc = ddy[i];
    if (ddx != nullptr) {
      const T inv_std = 1.0 / sqrt(variance[c] + epsilon);
      acc += ddx[i] * scale[c] * inv_std;
    }
    if (ddscale != nullptr) {
      const T inv_std = 1.0 / sqrt(variance[c] + epsilon);
      acc += (x[i] - mean[c]) * inv_std * ddscale[c];
    }
    if (ddbias != nullptr) {
      acc += ddbias[c];
    }
    ddy[i] = acc;
  }
}

// Host entry point for the batch-norm double-backward pass on the GPU.
//
// Inputs: X and dY are required. Scale is optional; a missing Scale is
// treated as all ones. ddX, ddScale and ddBias are optional, and nullptr
// means "absent".
// Statistics: when use_global_stats is true, Mean / Variance are passed to the
// *WithGlobal kernels, which apply 1 / sqrt(Variance + epsilon). Otherwise
// Saved_mean / Saved_variance are passed to kernels that multiply by
// Saved_variance directly (no sqrt, no epsilon).
// NOTE(review): those kernels therefore expect Saved_variance to already hold
// an inverse standard deviation — confirm with the forward kernel.
// Outputs: each of dX, dScale and ddY that is non-null is allocated,
// zero-filled, and then computed by the matching kernel on ctx.stream().
// Only DataLayout::kNHWC selects the channel-last kernels; any other layout
// value launches the NCHW kernels.
template <typename DeviceContext, typename T>
void NormDoubleGradFunctor(const DeviceContext &ctx,
                           const DataLayout data_layout,
                           const phi::DenseTensor *X,
                           const phi::DenseTensor *Scale,
                           const phi::DenseTensor *dY,
                           const phi::DenseTensor *Saved_mean,
                           const phi::DenseTensor *Saved_variance,
                           const phi::DenseTensor *Mean,
                           const phi::DenseTensor *Variance,
                           const double epsilon,
                           const bool use_global_stats,
                           const phi::DenseTensor *ddX,
                           const phi::DenseTensor *ddScale,
                           const phi::DenseTensor *ddBias,
                           phi::DenseTensor *dX,
                           phi::DenseTensor *dScale,
                           phi::DenseTensor *ddY) {
  // Raw device pointers; optional inputs become nullptr.
  const T *x_ptr = X->data<T>();
  const T *dy_ptr = dY->data<T>();
  const T *ddx_ptr = ddX ? ddX->data<T>() : nullptr;
  const T *ddscale_ptr = ddScale ? ddScale->data<T>() : nullptr;
  const T *ddbias_ptr = ddBias ? ddBias->data<T>() : nullptr;

  phi::funcs::SetConstant<DeviceContext, T> fill;

  const auto &dims = X->dims();
  const int C =
      (data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]);
  const int N = dims[0];
  const int num = X->numel();
  const int sample_size = num / N / C;
  const bool channel_last = (data_layout == DataLayout::kNHWC);

  // A missing Scale behaves like a vector of ones.
  phi::DenseTensor ones_scale;
  if (Scale == nullptr) {
    ones_scale.Resize({C});
    ctx.template Alloc<T>(&ones_scale);
    fill(ctx, &ones_scale, static_cast<T>(1));
  }
  const T *scale_ptr =
      (Scale != nullptr) ? Scale->data<T>() : ones_scale.data<T>();

#ifdef __HIPCC__
  constexpr int kBlockDim = 256;
#else
  constexpr int kBlockDim = 512;
#endif
  // Per-channel kernels use one block per channel, with the grid capped at
  // GetMaxPhysicalThreadCount() / kBlockDim (but at least 1).
  // Element-wise (*WithGlobal dx/ddy) kernels use one thread per element.
  const int max_threads = ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / kBlockDim, 1);
  const int grid = std::min(C, max_blocks);
  const int grid1 = (num + kBlockDim - 1) / kBlockDim;

  const T *mean_ptr =
      use_global_stats ? Mean->data<T>() : Saved_mean->data<T>();
  const T *var_ptr =
      use_global_stats ? Variance->data<T>() : Saved_variance->data<T>();

  if (dX != nullptr) {
    T *dx_ptr = ctx.template Alloc<T>(dX);
    fill(ctx, dX, static_cast<T>(0));
    if (use_global_stats) {
      if (channel_last) {
        DoubleGradComputeDXWithGlobal<T, DataLayout::kNHWC>
            <<<grid1, kBlockDim, 0, ctx.stream()>>>(dy_ptr,
                                                    ddscale_ptr,
                                                    var_ptr,
                                                    epsilon,
                                                    C,
                                                    sample_size,
                                                    num,
                                                    dx_ptr);
      } else {
        DoubleGradComputeDXWithGlobal<T, DataLayout::kNCHW>
            <<<grid1, kBlockDim, 0, ctx.stream()>>>(dy_ptr,
                                                    ddscale_ptr,
                                                    var_ptr,
                                                    epsilon,
                                                    C,
                                                    sample_size,
                                                    num,
                                                    dx_ptr);
      }
    } else if (channel_last) {
      DoubleGradComputeDX<T, kBlockDim, DataLayout::kNHWC>
          <<<grid, kBlockDim, 0, ctx.stream()>>>(x_ptr,
                                                 mean_ptr,
                                                 var_ptr,
                                                 ddx_ptr,
                                                 dy_ptr,
                                                 scale_ptr,
                                                 ddscale_ptr,
                                                 N,
                                                 C,
                                                 sample_size,
                                                 epsilon,
                                                 dx_ptr);
    } else {
      DoubleGradComputeDX<T, kBlockDim, DataLayout::kNCHW>
          <<<grid, kBlockDim, 0, ctx.stream()>>>(x_ptr,
                                                 mean_ptr,
                                                 var_ptr,
                                                 ddx_ptr,
                                                 dy_ptr,
                                                 scale_ptr,
                                                 ddscale_ptr,
                                                 N,
                                                 C,
                                                 sample_size,
                                                 epsilon,
                                                 dx_ptr);
    }
  }

  if (dScale != nullptr) {
    T *dscale_ptr = ctx.template Alloc<T>(dScale);
    fill(ctx, dScale, static_cast<T>(0));
    if (use_global_stats) {
      if (channel_last) {
        DoubleGradComputeDScaleWithGlobal<T, kBlockDim, DataLayout::kNHWC>
            <<<grid, kBlockDim, 0, ctx.stream()>>>(ddx_ptr,
                                                   var_ptr,
                                                   dy_ptr,
                                                   epsilon,
                                                   N,
                                                   C,
                                                   sample_size,
                                                   dscale_ptr);
      } else {
        DoubleGradComputeDScaleWithGlobal<T, kBlockDim, DataLayout::kNCHW>
            <<<grid, kBlockDim, 0, ctx.stream()>>>(ddx_ptr,
                                                   var_ptr,
                                                   dy_ptr,
                                                   epsilon,
                                                   N,
                                                   C,
                                                   sample_size,
                                                   dscale_ptr);
      }
    } else if (channel_last) {
      DoubleGradComputeDScale<T, kBlockDim, DataLayout::kNHWC>
          <<<grid, kBlockDim, 0, ctx.stream()>>>(x_ptr,
                                                 mean_ptr,
                                                 var_ptr,
                                                 ddx_ptr,
                                                 dy_ptr,
                                                 N,
                                                 C,
                                                 sample_size,
                                                 epsilon,
                                                 dscale_ptr);
    } else {
      DoubleGradComputeDScale<T, kBlockDim, DataLayout::kNCHW>
          <<<grid, kBlockDim, 0, ctx.stream()>>>(x_ptr,
                                                 mean_ptr,
                                                 var_ptr,
                                                 ddx_ptr,
                                                 dy_ptr,
                                                 N,
                                                 C,
                                                 sample_size,
                                                 epsilon,
                                                 dscale_ptr);
    }
  }

  if (ddY != nullptr) {
    T *ddy_ptr = ctx.template Alloc<T>(ddY);
    fill(ctx, ddY, static_cast<T>(0));
    if (use_global_stats) {
      if (channel_last) {
        DoubleGradComputeDDYWithGlobal<T, DataLayout::kNHWC>
            <<<grid1, kBlockDim, 0, ctx.stream()>>>(ddx_ptr,
                                                    scale_ptr,
                                                    mean_ptr,
                                                    var_ptr,
                                                    x_ptr,
                                                    ddbias_ptr,
                                                    ddscale_ptr,
                                                    epsilon,
                                                    C,
                                                    sample_size,
                                                    num,
                                                    ddy_ptr);
      } else {
        DoubleGradComputeDDYWithGlobal<T, DataLayout::kNCHW>
            <<<grid1, kBlockDim, 0, ctx.stream()>>>(ddx_ptr,
                                                    scale_ptr,
                                                    mean_ptr,
                                                    var_ptr,
                                                    x_ptr,
                                                    ddbias_ptr,
                                                    ddscale_ptr,
                                                    epsilon,
                                                    C,
                                                    sample_size,
                                                    num,
                                                    ddy_ptr);
      }
    } else if (channel_last) {
      DoubleGradComputeDDY<T, kBlockDim, DataLayout::kNHWC>
          <<<grid, kBlockDim, 0, ctx.stream()>>>(x_ptr,
                                                 mean_ptr,
                                                 var_ptr,
                                                 ddscale_ptr,
                                                 ddbias_ptr,
                                                 ddx_ptr,
                                                 scale_ptr,
                                                 N,
                                                 C,
                                                 sample_size,
                                                 epsilon,
                                                 ddy_ptr);
    } else {
      DoubleGradComputeDDY<T, kBlockDim, DataLayout::kNCHW>
          <<<grid, kBlockDim, 0, ctx.stream()>>>(x_ptr,
                                                 mean_ptr,
                                                 var_ptr,
                                                 ddscale_ptr,
                                                 ddbias_ptr,
                                                 ddx_ptr,
                                                 scale_ptr,
                                                 N,
                                                 C,
                                                 sample_size,
                                                 epsilon,
                                                 ddy_ptr);
    }
  }
}

// Sums (x_sum, x_square_sum) across threadIdx.y, independently for each
// threadIdx.x column of a 2-D block. The reduction halves the active rows on
// each step and exchanges values through the scratch buffers smem_sum /
// smem_square_sum, each of which needs blockDim.x * blockDim.y elements.
// blockDim.y must be a power of two; otherwise trailing rows are dropped.
// Only threads with threadIdx.y == 0 write the column totals to
// *x_sum_out / *x_square_sum_out.
// Contains __syncthreads(), so every thread of the block must call it.
// The template parameter T is unused. The function name keeps its original
// spelling for compatibility.
template <typename T, typename BnT>
__device__ __forceinline__ void BlockReduceByVetical(BnT x_sum,
                                                     BnT x_square_sum,
                                                     BnT *smem_sum,
                                                     BnT *smem_square_sum,
                                                     BnT *x_sum_out,
                                                     BnT *x_square_sum_out) {
  const int flat_tid = threadIdx.y * blockDim.x + threadIdx.x;
#pragma unroll
  for (int half = blockDim.y / 2; half > 0; half /= 2) {
    // Rows [0, 2 * half) still hold live partials: publish them.
    if (threadIdx.y < 2 * half) {
      smem_sum[flat_tid] = x_sum;
      smem_square_sum[flat_tid] = x_square_sum;
    }
    __syncthreads();
    // Rows [0, half) fold in the partial held `half` rows further down.
    if (threadIdx.y < half) {
      const int partner = flat_tid + half * blockDim.x;
      x_sum += smem_sum[partner];
      x_square_sum += smem_square_sum[partner];
    }
  }
  if (threadIdx.y == 0) {
    *x_sum_out = x_sum;
    *x_square_sum_out = x_square_sum;
  }
}

// Finishes a per-channel (sum, square-sum) reduction that is split across
// blockIdx.y.
//
// Steps:
// 1. Every block stages the partial (*sum1, *sum2) held by its
//    threadIdx.y == 0 row for channel c into the global scratch buffer
//    `block_data_ptr` (two planes of C * gridDim.y values).
// 2. It fences, then thread (0, 0) atomically increments
//    flag_ptr[blockIdx.x].
// 3. The block that arrives last (*is_last_block_done == true) adds up all
//    gridDim.y partials and reduces them across threadIdx.y with
//    BlockReduceByVetical. Its threadIdx.y == 0 threads then hold the
//    totals in *sum1 / *sum2.
//
// All other blocks return with *sum1 / *sum2 unchanged.
// Preconditions: flag_ptr must be zeroed before the launch, and
// is_last_block_done must point to memory visible to the whole block
// (e.g. __shared__), because one thread writes it and all threads read it.
// NOTE(review): SetLaunchConfigInfoForChannelLast appears to be what
// allocates and zeroes these buffers — confirm at the call sites.
template <typename T, typename BnT>
__device__ __forceinline__ void ReduceSumPost(const int C,  // channels
                                              const int c,  // channel index
                                              BnT *sum1,
                                              BnT *sum2,
                                              bool *is_last_block_done,
                                              BnT *cache1,
                                              BnT *cache2,
                                              BnT *block_data_ptr,
                                              int *flag_ptr) {
  // Two planes of C * gridDim.y partials: plain sums, then squared sums.
  // volatile forces the last block to re-read other blocks' stores.
  volatile BnT *partials = block_data_ptr;
  volatile BnT *sq_partials = block_data_ptr + C * gridDim.y;

  if (threadIdx.y == 0) {
    partials[c + blockIdx.y * C] = *sum1;
    sq_partials[c + blockIdx.y * C] = *sum2;
  }

  // Publish the staged partials device-wide before signalling arrival.
  __threadfence();
  __syncthreads();

  if (threadIdx.x == 0 && threadIdx.y == 0) {
    const int arrived_before = atomicAdd(&flag_ptr[blockIdx.x], 1);
    *is_last_block_done = (arrived_before == (gridDim.y - 1));
  }
  __syncthreads();

  if (!*is_last_block_done) {
    return;
  }

  // Each thread row accumulates a strided subset of the gridDim.y partials.
  BnT total = static_cast<BnT>(0);
  BnT sq_total = static_cast<BnT>(0);
  for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
    total += partials[c + y * C];
    sq_total += sq_partials[c + y * C];
  }
  *sum1 = total;
  *sum2 = sq_total;

  // Collapse the thread rows; row 0 receives the channel totals.
  funcs::BlockReduceByVetical<T, BnT>(
      total, sq_total, &cache1[0], &cache2[0], sum1, sum2);
}

// Picks a 2-D launch configuration for channel-last per-channel reductions.
//
// Block shape:
//   - block.x spans channels. It starts as min(GetLastPow2(C), 32); if
//     block.x * block.y would not equal block_size, it is widened to
//     min(GetLastPow2(C), block_size / block.y).
//   - block.y spans the N*H*W*D reduction axis. It is
//     min(GetLastPow2(N*H*W*D / 16), block_size / block.x).
// Grid shape:
//   - grid.x = ceil(C / block.x).
//   - grid.y = min(ceil(N*H*W*D / (16 * block.y)), 128), i.e. about 16
//     elements per thread along that axis until the cap applies.
//
// When grid.y > 1 it also allocates a staging buffer of 2 * C * grid.y BnT
// values and a zero-filled int flag per grid.x column, and returns their
// pointers through block_data_ptr / flag_ptr (the layout ReduceSumPost
// reads). Otherwise those outputs are left untouched. T is unused here.
template <typename T, typename BnT, typename Context>
void SetLaunchConfigInfoForChannelLast(const Context &ctx,
                                       DenseTensor *block_data_tensor,
                                       DenseTensor *flag_tensor,
                                       BnT **block_data_ptr,
                                       int **flag_ptr,
                                       const int N,
                                       const int H,
                                       const int W,
                                       const int D,
                                       const int C,
                                       const int block_size,
                                       dim3 *block,
                                       dim3 *grid) {
  const int kMaxGridY = 128;
  const int kWarpSize = 32;
  const int kRowsPerThread = 16;

  const int reduce_len = N * H * W * D;
  const int channel_pow2 = phi::funcs::details::GetLastPow2(C);

  int threads_x = std::min(channel_pow2, kWarpSize);
  const int threads_y =
      std::min(phi::funcs::details::GetLastPow2(reduce_len / kRowsPerThread),
               block_size / threads_x);
  if (threads_x * threads_y != block_size) {
    // Short reduction axis: widen block.x to use the remaining threads.
    threads_x = std::min(channel_pow2, block_size / threads_y);
  }
  const int rows_per_block = threads_y * kRowsPerThread;
  const int blocks_x = (C + threads_x - 1) / threads_x;
  const int blocks_y =
      std::min((reduce_len + rows_per_block - 1) / rows_per_block, kMaxGridY);

  block->x = threads_x;
  block->y = threads_y;
  grid->x = blocks_x;
  grid->y = blocks_y;

  // A single block along y needs no cross-block staging.
  if (grid->y <= 1) {
    return;
  }
  *block_data_tensor = phi::Empty<BnT, Context>(ctx, {2 * C * grid->y});
  *flag_tensor = phi::Empty<int, Context>(ctx, {grid->x});

  *block_data_ptr = block_data_tensor->data<BnT>();
  *flag_ptr = flag_tensor->data<int>();
  funcs::SetConstant<Context, int> zero_fill;
  zero_fill(ctx, flag_tensor, static_cast<int>(0));
}

}  // namespace funcs
}  // namespace phi