norm_grad_kernel.cu 3.9 KB
Newer Older
H
hong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
16

17
#include "paddle/phi/kernels/norm_grad_kernel.h"
H
hong 已提交
18 19 20 21 22 23 24 25
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
26
#include "paddle/phi/backends/gpu/gpu_context.h"
27
#include "paddle/phi/common/bfloat16.h"
28 29
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
H
hong 已提交
30

31
namespace phi {

// Backward kernel of the L2 `norm` op along one axis.
//
// The input is addressed as a row-major [pre, axis_n, post] tensor: each of
// the pre * post "rows" is the strided sequence
//   x[p * axis_n * post + j * post + q],  j in [0, axis_n),
// and row i = p * post + q uses the scalar s = x_norm[i]. For every element
// of a row the kernel writes
//   x_grad = (y_grad - x * sum_row(x * y_grad) / s^2) / s
// accumulating in MT (the math type MPTypeTrait<T> selects for T).
//
// Launch requirements: 1-D grid of any size >= 1 (blocks step over rows by
// gridDim.x), and blockDim.x must equal BlockDim because cub::BlockReduce is
// instantiated with it. Uses static shared memory only.
template <typename T, int BlockDim>
__global__ void NormalizeGradient(const T* x,
                                  const T* x_norm,
                                  const T* y_grad,
                                  const int pre,
                                  const int axis_n,
                                  const int post,
                                  T* x_grad) {
  using MT = typename paddle::operators::details::MPTypeTrait<T>::Type;
  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage_sum;
  // Per-row values computed by thread 0 and broadcast to the whole block.
  __shared__ MT row_sum;
  __shared__ MT row_sqrt_norm;
  __shared__ MT row_norm;

  // 64-bit offsets: pre * axis_n * post may exceed INT_MAX even though each
  // factor individually fits in an int.
  const int64_t num = static_cast<int64_t>(pre) * post;
  const int64_t stride = post;
  for (int64_t i = blockIdx.x; i < num; i += gridDim.x) {
    const int64_t base = (i / post) * post * axis_n + (i % post);

    // Partial dot product <x, y_grad> over this thread's slice of the row.
    MT sum = static_cast<MT>(0);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const int64_t index = base + j * stride;
      sum += static_cast<MT>(x[index]) * static_cast<MT>(y_grad[index]);
    }
    // The block-wide result is only valid in thread 0.
    MT reduce_result = BlockReduce(temp_storage_sum).Sum(sum);

    if (threadIdx.x == 0) {
      row_sum = reduce_result;
      row_sqrt_norm = static_cast<MT>(x_norm[i]);
      row_norm = row_sqrt_norm * row_sqrt_norm;
    }
    // Publish row_* to every thread before they are read.
    __syncthreads();

    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const int64_t index = base + j * stride;
      const MT x_ij = static_cast<MT>(x[index]);
      const MT dy_ij = static_cast<MT>(y_grad[index]);
      x_grad[index] =
          static_cast<T>((dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm);
    }
    // CUB requires a barrier before temp_storage_sum is reused, and this also
    // stops thread 0 from overwriting row_* for the next row while other
    // threads are still reading them. The loop bound depends only on
    // blockIdx/gridDim, so every thread of the block reaches this barrier.
    __syncthreads();
  }
}

// GPU implementation of norm_grad: fills x_grad from x, the per-row norm
// tensor `norm` and out_grad, normalizing along `axis` (negative values count
// from the last dimension).
//
// `norm` is indexed with one value per (pre, post) row; presumably it is the
// norm saved by the forward op (TODO confirm against the forward kernel).
// `epsilon` and `is_test` are accepted for interface compatibility but are
// not used by this implementation.
template <typename T, typename Context>
void NormGradKernel(const Context& ctx,
                    const DenseTensor& x,
                    const DenseTensor& norm,
                    const DenseTensor& out_grad,
                    int axis,
                    float epsilon,
                    bool is_test,
                    DenseTensor* x_grad) {
  T* dx = ctx.template Alloc<T>(x_grad);
  // Empty input: nothing to compute. Launching anyway could request a grid of
  // size 0, which is an invalid launch configuration.
  if (x.numel() == 0) return;

  const T* x_data = x.data<T>();
  const T* x_norm = norm.data<T>();
  const T* dy = out_grad.data<T>();

  auto xdim = x.dims();
  if (axis < 0) axis = xdim.size() + axis;
  int pre, n, post;
  funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post);

#ifdef __HIPCC__
  constexpr int block = 256;
#else
  constexpr int block = 512;
#endif
  const int max_threads = ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / block, 1);
  // Row count in 64-bit so pre * post cannot overflow before clamping.
  const int64_t rows = static_cast<int64_t>(pre) * post;
  const int grid = static_cast<int>(std::min<int64_t>(max_blocks, rows));

  NormalizeGradient<T, block>
      <<<grid, block, 0, ctx.stream()>>>(x_data, x_norm, dy, pre, n, post, dx);
}

}  // namespace phi
H
hong 已提交
112

113
// Register the GPU norm_grad kernel for float, double and float16 inputs.
PD_REGISTER_KERNEL(norm_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::NormGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}