/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cooperative_groups.h>
#include <cuda.h>
#include <curand_kernel.h>

#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
23
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
24
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
25 26
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
27 28
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
29
#include "paddle/phi/kernels/funcs/aligned_vector.h"
30
#include "paddle/phi/kernels/funcs/functors.h"
31 32 33 34 35 36 37 38 39 40 41 42 43

namespace paddle {
namespace operators {

#define CACHE_LINE 128
#define MAX_CACHE_BYTES (CACHE_LINE / CHAR_BIT)

/**
 * Get the thread/block configuration for fused_residual_dropout_bias:
 * 1D blocks: blockDim.x covers cols / vec_size (capped at 512 threads)
 * 2D grids: gridDim.y = rows
 */
inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
    const phi::GPUContext &ctx,
    const uint32_t rows,
    const uint32_t cols,
    const int vec_size) {
  const uint32_t tmp_cols = cols / vec_size;
  // NOTE(wangxi): We cap max_block_size at 512 because
  // `FusedResidualDropoutBias` needs too many register resources. If the data
  // type is float16 and block_size is 1024, the launch fails with CUDA error
  // 701 (cudaErrorLaunchOutOfResources), which means the kernel could not be
  // launched with the resources it requested. The kernel can be optimized
  // later to reduce its register usage.
  int threads = std::max(static_cast<uint32_t>(32),
                         std::min(tmp_cols,
                                  static_cast<uint32_t>(std::min(
                                      ctx.GetMaxThreadsPerBlock(), 512))));
  const auto blocks_x =
      std::max(static_cast<uint32_t>(1), (tmp_cols + threads - 1) / threads);
  const auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
  platform::GpuLaunchConfig config;
  config.block_per_grid.x = blocks_x;
  config.block_per_grid.y = blocks_y;
  config.thread_per_block.x = threads;
  return config;
}
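
// Illustrative usage sketch (an assumption, not part of this header):
// launching a hypothetical `SomeFusedKernel` with the config computed by
// Get1DBlocksAnd2DGrids. `rows`, `cols`, and `VecSize` are caller-supplied.
//
//   auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, VecSize);
//   SomeFusedKernel<T, VecSize>
//       <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
//           src, residual, bias, dst, rows, cols);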

template <int VecSize>
__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
                                        float *data);

template <>
__forceinline__ __device__ void RandVec<1>(curandStatePhilox4_32_10_t *state,
                                           float *data) {
  data[0] = curand_uniform(state);
}

template <>
__forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state,
                                           float *data) {
  data[0] = curand_uniform(state);
  data[1] = curand_uniform(state);
}

template <>
__forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state,
                                           float *data) {
  float4 rand4 = curand_uniform4(state);
  // The order in which the four components are consumed does not matter:
  // all four are i.i.d. uniform samples.
  data[0] = rand4.x;
  data[1] = rand4.y;
  data[2] = rand4.w;
  data[3] = rand4.z;
}

template <>
__forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state,
                                           float *data) {
  RandVec<4>(state, data);
  RandVec<4>(state, data + 4);
}
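
// Illustrative usage sketch (an assumption, not part of this header):
// drawing VecSize uniform floats per step inside a dropout-style kernel.
// `seed`, `increment`, and `idx` are hypothetical caller-supplied values.
//
//   curandStatePhilox4_32_10_t state;
//   curand_init(seed, idx, increment, &state);
//   float rands[VecSize];
//   RandVec<VecSize>(&state, rands);
//   // keep element i iff rands[i] >= dropout_prob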

template <typename T>
inline void SetZero(const phi::GPUContext &ctx, T *ptr, const size_t size) {
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaMemsetAsync(ptr, 0, size * sizeof(T), ctx.stream()));
}
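
// Illustrative usage sketch (hypothetical names): zeroing a bias-gradient
// buffer asynchronously on the context's stream before accumulating into it.
//
//   SetZero<float>(ctx, d_bias_ptr, bias_numel);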

/**
 * Reduce the per-thread partial sums of dbias: for each of the
 * BlockSizeX * VecSize columns, BlockSizeY (typically 128) partial values are
 * reduced to a single value, one warp per column.
 */
template <typename T, int VecSize, int BlockSizeX, int BlockSizeY>
inline __device__ void CalculateDBias(const T *tmp_sum,
                                      T *dbias,
                                      const int cols) {
  // save temporary sum to cache and do transpose
  __shared__ T cache[BlockSizeX * VecSize][BlockSizeY];
  for (int i = 0; i < VecSize; i++) {
    cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i];
  }
  __syncthreads();
  // reduce sum
  // Each warp handles at most two columns (assumes BlockSizeX * VecSize <= 64).
  T sum[2] = {static_cast<T>(0)};
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 5;  // warp id
  int y = tid & 31;  // lane id within the warp, 0~31

  // need BlockSizeX * VecSize warps
  for (int j = x; j < BlockSizeX * VecSize; j += 32) {
// reduce 128 to 32: each lane accumulates BlockSizeY / 32 values of column j
#pragma unroll
    for (int i = 0; i < (BlockSizeY >> 5); i++) {
      sum[(j >> 5)] += cache[j][y + i * 32];
    }
  }

  int reduce_num_per_thread = (BlockSizeX * VecSize + 31) / 32;
  // reduce 32 to 1
  for (int i = 0; i < reduce_num_per_thread; i++) {
    sum[i] = WarpReduceSum(sum[i]);
  }

  // save sum to dbias
  if (y == 0 && x < BlockSizeX * VecSize) {
    for (int i = 0; i < reduce_num_per_thread; i++) {
      int bias_id = blockIdx.x * BlockSizeX * VecSize + x + i * 32;
      if (bias_id < cols) {
        dbias[bias_id] = sum[i];
      }
    }
  }
}
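
// Illustrative usage sketch (an assumption, not part of this header): called
// at the end of a backward kernel once every thread has accumulated its
// per-column partial sums. Assumes blockDim == (BlockSizeX, BlockSizeY) with
// BlockSizeY a multiple of 32; `partial` is a hypothetical accumulator.
//
//   T partial[VecSize];
//   // ... accumulate d_out over this thread's rows into partial ...
//   CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(partial, dbias, cols);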

template <typename T>
inline __device__ T GetFactor(const float dropout_prob,
                              const bool is_upscale_in_train,
                              const bool is_test) {
  T factor = is_upscale_in_train ? static_cast<T>(1.0f / (1.0f - dropout_prob))
                                 : static_cast<T>(1.0f);
  if (is_test) {
    factor = is_upscale_in_train ? static_cast<T>(1.0f)
                                 : static_cast<T>(1.0f - dropout_prob);
  }
  return factor;
}
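
// Illustrative numeric check of the two dropout conventions, with an assumed
// dropout_prob of 0.1: upscale_in_train scales kept values by
// 1 / (1 - 0.1) ~= 1.111 during training and uses factor 1.0 at test time;
// otherwise training uses factor 1.0 and test time scales by 1 - 0.1 = 0.9.
//
//   float train_factor = GetFactor<float>(0.1f, /*is_upscale_in_train=*/true,
//                                         /*is_test=*/false);  // ~1.111f
//   float test_factor = GetFactor<float>(0.1f, /*is_upscale_in_train=*/true,
//                                        /*is_test=*/true);    // 1.0f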
}  // namespace operators
}  // namespace paddle