// fused_dropout_common.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cooperative_groups.h>
#include <cuda.h>
#include <curand_kernel.h>

#include <algorithm>
#include <climits>

#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// NOTE(review): CACHE_LINE is used as a *bit* count here, not a byte count:
// MAX_CACHE_BYTES = CACHE_LINE / CHAR_BIT, i.e. 128 / 8 = 16 bytes when
// CHAR_BIT == 8. Presumably MAX_CACHE_BYTES bounds the width of one
// vectorized per-thread memory access (16 bytes) — confirm at call sites.
#define CACHE_LINE 128
#define MAX_CACHE_BYTES (CACHE_LINE / CHAR_BIT)

/**
 * Launch configuration used by fused_residual_dropout_bias:
 *   - 1D blocks: blockDim.x = cols / vec_size, clamped from above by
 *     ctx.GetMaxThreadsPerBlock() and from below by 32;
 *   - 2D grids:  gridDim.x = ceil((cols / vec_size) / blockDim.x), at least 1,
 *                gridDim.y = rows, at least 1.
 * Presumably each thread handles vec_size consecutive columns of one row.
 *
 * NOTE(review): cols / vec_size truncates, so vec_size is presumably chosen
 * by callers to divide cols — confirm. vec_size must be non-zero.
 *
 * @param ctx      device context, queried for the per-block thread limit.
 * @param rows     number of rows (mapped to gridDim.y).
 * @param cols     number of columns.
 * @param vec_size columns per thread.
 */
inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
    const platform::CUDADeviceContext &ctx, const uint32_t rows,
    const uint32_t cols, const int vec_size) {
  const uint32_t threads_needed = cols / vec_size;
  const uint32_t thread_limit =
      static_cast<uint32_t>(ctx.GetMaxThreadsPerBlock());

  // Clamp to the device limit, but never go below 32 threads (one warp).
  uint32_t block_dim_x =
      threads_needed < thread_limit ? threads_needed : thread_limit;
  if (block_dim_x < 32) {
    block_dim_x = 32;
  }

  // Enough blocks along x to cover every vectorized column, and at least one.
  uint32_t grid_dim_x = (threads_needed + block_dim_x - 1) / block_dim_x;
  if (grid_dim_x == 0) {
    grid_dim_x = 1;
  }
  const uint32_t grid_dim_y = (rows == 0) ? 1 : rows;

  platform::GpuLaunchConfig config;
  config.block_per_grid.x = grid_dim_x;
  config.block_per_grid.y = grid_dim_y;
  config.thread_per_block.x = block_dim_x;
  return config;
}

/**
 * Writes VecSize random floats drawn from the Philox `state` into
 * data[0 .. VecSize). There is no generic definition; only explicit
 * specializations (VecSize = 1, 2, 4, 8 in this file) can be used.
 */
template <int VecSize>
__forceinline__ __device__ void RandVec(
    curandStatePhilox4_32_10_t *state, float *data);

// Single-value variant: one curand_uniform() draw, which lies in (0, 1].
template <>
__forceinline__ __device__ void RandVec<1>(curandStatePhilox4_32_10_t *state,
                                           float *data) {
  *data = curand_uniform(state);
}

// Two-value variant: data[0] and data[1] come from two successive
// curand_uniform() calls on the same state, in that order.
template <>
__forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state,
                                           float *data) {
#pragma unroll
  for (int idx = 0; idx < 2; ++idx) {
    data[idx] = curand_uniform(state);
  }
}

// Four-value variant: one curand_uniform4() call. The components are stored
// in x, y, z, w order so that data[i] is component i of the float4.
// NOTE: an earlier revision stored x, y, w, z. The four values are
// independent uniforms, so the distribution is unchanged, but the exact
// values landing in data[2]/data[3] for a given seed/offset differ from that
// ordering.
template <>
__forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state,
                                           float *data) {
  const float4 rand4 = curand_uniform4(state);
  data[0] = rand4.x;
  data[1] = rand4.y;
  data[2] = rand4.z;
  data[3] = rand4.w;
}

// Eight-value variant: two back-to-back RandVec<4> draws on the same state,
// filling data[0..3] first and then data[4..7].
template <>
__forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state,
                                           float *data) {
#pragma unroll
  for (int offset = 0; offset < 8; offset += 4) {
    RandVec<4>(state, data + offset);
  }
}

/**
 * Enqueues a zero-fill of `size` elements of type T starting at `ptr` on
 * ctx.stream() via cudaMemsetAsync. The fill is stream-ordered and may still
 * be pending when this function returns. The returned CUDA status is checked
 * with PADDLE_ENFORCE_GPU_SUCCESS.
 * `ptr` is assumed to be device-accessible memory (cudaMemsetAsync requires it).
 */
template <typename T>
inline void SetZero(const platform::CUDADeviceContext &ctx, T *ptr,
                    const size_t size) {
  const size_t num_bytes = size * sizeof(T);
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(ptr, 0, num_bytes, ctx.stream()));
}

/**
 * Reduces per-thread partial column sums over the rows of the block and
 * writes one value per column to dbias.
 *
 * Usage assumptions (from the indexing below):
 *   - blockDim == (BlockSizeX, BlockSizeY). Every thread of the block must
 *     call this function because it contains __syncthreads().
 *   - Thread (tx, ty) passes VecSize partial sums in tmp_sum. tmp_sum[i]
 *     belongs to block-local column tx * VecSize + i.
 *   - Block blockIdx.x owns global columns
 *     [blockIdx.x * BlockSizeX * VecSize, (blockIdx.x + 1) * BlockSizeX * VecSize).
 *   - BlockSizeX * BlockSizeY must be a multiple of 32 so every warp is full
 *     when WarpReduceSum runs. This is checked at compile time.
 *
 * Algorithm:
 *   1. Partial sums are transposed through shared memory (cache[col][row]).
 *   2. Warp w then owns block-local columns w, w + num_warps, ...
 *   3. Each lane adds a strided subset of that column's rows.
 *   4. WarpReduceSum combines the lanes, and lane 0 stores the result.
 * Columns whose global index is >= cols, or that belong to the next block,
 * are never written.
 *
 * @param tmp_sum VecSize partial sums of the calling thread.
 * @param dbias   output buffer indexed by global column.
 * @param cols    total number of columns.
 */
template <typename T, int VecSize, int BlockSizeX, int BlockSizeY>
inline __device__ void CalculateDBias(const T *tmp_sum, T *dbias,
                                      const int cols) {
  constexpr int kWarpSize = 32;
  static_assert((BlockSizeX * BlockSizeY) % kWarpSize == 0,
                "CalculateDBias requires BlockSizeX * BlockSizeY to be a "
                "multiple of 32 (full warps).");
  // Number of columns owned by one block.
  constexpr int kLocalCols = BlockSizeX * VecSize;
  constexpr int kNumWarps = BlockSizeX * BlockSizeY / kWarpSize;
  // Columns reduced by each warp. It is sized at compile time so the
  // register array below can never be indexed out of bounds.
  constexpr int kColsPerWarp = (kLocalCols + kNumWarps - 1) / kNumWarps;
  // Rows visited by each lane per column. The ceiling covers BlockSizeY values
  // that are not a multiple of 32.
  constexpr int kRowsPerLane = (BlockSizeY + kWarpSize - 1) / kWarpSize;

  // Save the partial sums to shared memory, transposed to cache[col][row].
  __shared__ T cache[kLocalCols][BlockSizeY];
#pragma unroll
  for (int i = 0; i < VecSize; i++) {
    cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i];
  }
  __syncthreads();

  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int warp_id = tid >> 5;  // tid / 32
  const int lane_id = tid & 31;  // tid % 32

  // Per-lane partial sums over rows lane_id, lane_id + 32, ...
  T sum[kColsPerWarp];
#pragma unroll
  for (int k = 0; k < kColsPerWarp; k++) {
    sum[k] = static_cast<T>(0);
    const int local_col = warp_id + k * kNumWarps;
    if (local_col < kLocalCols) {
#pragma unroll
      for (int r = 0; r < kRowsPerLane; r++) {
        const int row = lane_id + r * kWarpSize;
        // The bound check folds away when BlockSizeY is a multiple of 32.
        if (BlockSizeY % kWarpSize == 0 || row < BlockSizeY) {
          sum[k] += cache[local_col][row];
        }
      }
    }
  }

  // Every lane of every warp reaches this loop (trip count is a compile-time
  // constant), so the full-warp reduction is well-formed.
#pragma unroll
  for (int k = 0; k < kColsPerWarp; k++) {
    sum[k] = WarpReduceSum(sum[k]);
  }

  // Lane 0 of each warp stores its columns.
  if (lane_id == 0) {
#pragma unroll
    for (int k = 0; k < kColsPerWarp; k++) {
      const int local_col = warp_id + k * kNumWarps;
      const int bias_id = blockIdx.x * kLocalCols + local_col;
      // The local_col bound keeps a block from writing (zeros) into a column
      // owned by the next block. The cols bound handles the ragged tail.
      if (local_col < kLocalCols && bias_id < cols) {
        dbias[bias_id] = sum[k];
      }
    }
  }
}

/**
 * Returns the scale factor applied to the dropout output.
 *
 *   training, is_upscale_in_train:   1 / (1 - dropout_prob)
 *   training, !is_upscale_in_train:  1
 *   test,     is_upscale_in_train:   1
 *   test,     !is_upscale_in_train:  1 - dropout_prob
 *
 * 1 / (1 - dropout_prob) is not finite at dropout_prob == 1. Float division
 * gives inf, and inf * 0 gives NaN. For dropout_prob >= 1 every element is
 * expected to be dropped (TODO confirm mask convention at call sites), so the
 * training/upscale case returns 0 there instead.
 *
 * Declared __host__ __device__ so host code (and tests) can use it too.
 *
 * @param dropout_prob        probability of dropping an element.
 * @param is_upscale_in_train whether kept values are upscaled during training.
 * @param is_test             whether the op runs in inference mode.
 */
template <typename T>
inline __host__ __device__ T GetFactor(const float dropout_prob,
                                       const bool is_upscale_in_train,
                                       const bool is_test) {
  if (is_test) {
    return is_upscale_in_train ? static_cast<T>(1.0f)
                               : static_cast<T>(1.0f - dropout_prob);
  }
  if (!is_upscale_in_train) {
    return static_cast<T>(1.0f);
  }
  if (dropout_prob >= 1.0f) {
    return static_cast<T>(0.0f);
  }
  return static_cast<T>(1.0f / (1.0f - dropout_prob));
}

}  // namespace operators
}  // namespace paddle