fused_residual_dropout_bias.h 10.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/fused_dropout_common.h"

namespace paddle {
namespace operators {

/**
 * @brief The fused function called by every thread.
 *
 * Processes VecSize consecutive elements of row `row_id`, starting at column
 * `col_id`, and computes
 *   dst = residual + dropout(act(src + bias))
 * where `act` is applied only when the Activation template flag is set.
 * VecSize can be 1, 2, 4 or 8.
 *
 * @param row_id        row of the (rows, cols) tensors handled by this call
 * @param col_id        first column handled by this call
 * @param cols          number of columns (row stride of src/residual/dst/mask)
 * @param state         RNG state handed to RandVec (only used if !is_test)
 * @param dropout_prob  an element is kept when its random draw is
 *                      >= dropout_prob
 * @param factor        scale multiplied onto the masked value
 * @param src           input, indexed at row_id * cols + col_id
 * @param residual      optional residual (may be null, treated as zero)
 * @param bias          optional bias indexed by column only (may be null,
 *                      treated as zero)
 * @param dst           output, same indexing as src
 * @param mask          dropout mask output; written only when !is_test
 * @param is_test       when true no random numbers are drawn and every
 *                      element is kept
 * @param mean_val      when ComputeLayerNorm is set, accumulates the sum of
 *                      the outputs; not dereferenced otherwise
 * @param var_val       when ComputeLayerNorm is set, accumulates the sum of
 *                      the squared outputs; not dereferenced otherwise
 * @param act_func      activation functor, used only when Activation is set
 */
template <typename T, typename MaskType, int VecSize, bool ComputeLayerNorm,
          bool Activation, typename Functor>
__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
    const int row_id, const int col_id, const int cols,
    curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor,
    const T *__restrict__ src, const T *__restrict__ residual,
    const T *__restrict__ bias, T *dst, MaskType *mask, const bool is_test,
    typename details::MPTypeTrait<T>::Type *mean_val,
    typename details::MPTypeTrait<T>::Type *var_val, Functor act_func) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
  using U = typename details::MPTypeTrait<T>::Type;

  // Flat element offset shared by src, residual, dst and mask. Computed once
  // in 64 bits so that row_id * cols cannot overflow a 32-bit int on tensors
  // with more than 2^31 elements.
  const int64_t offset = static_cast<int64_t>(row_id) * cols + col_id;

  LoadT src_vec;
  LoadT residual_vec;
  LoadT bias_vec;
  // residual and bias are optional: default them to zero so the arithmetic
  // below needs no further null checks.
#pragma unroll
  for (int ii = 0; ii < VecSize; ii++) {
    bias_vec[ii] = static_cast<T>(0);
    residual_vec[ii] = static_cast<T>(0);
  }

  // vectorize load data from global
  phi::Load<T, VecSize>(&src[offset], &src_vec);
  if (residual) {
    phi::Load<T, VecSize>(&residual[offset], &residual_vec);
  }
  if (bias) {
    // bias is broadcast over rows, so it is indexed by column only
    phi::Load<T, VecSize>(&bias[col_id], &bias_vec);
  }

  MaskStoreT mask_vec;
  if (!is_test) {
    float rand[VecSize];
    RandVec<VecSize>(state, rand);
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      // 1 keeps the element, 0 drops it
      mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
    }
  } else {
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      mask_vec[ii] = static_cast<MaskType>(1);
    }
  }

  StoreT dest_vec;

#pragma unroll
  for (int ii = 0; ii < VecSize; ii++) {
    T tmp = src_vec[ii] + bias_vec[ii];
    if (Activation) {
      tmp = act_func(tmp);
    }
    dest_vec[ii] =
        tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
    if (ComputeLayerNorm) {
      // accumulate in the wider type U for the caller's statistics
      const U out_val = static_cast<U>(dest_vec[ii]);
      *mean_val += out_val;
      *var_val += (out_val * out_val);
    }
  }

  // store result to global
  phi::Store<T, VecSize>(dest_vec, &dst[offset]);
  if (!is_test) {
    phi::Store<MaskType, VecSize>(mask_vec, &mask[offset]);
  }
}

/**
 * @brief dst = residual + dropout(src + bias);
 * the src, residual, mask and dst shape is (rows, cols)
 * the bias shape is (1, cols)
 * is_test: only used in inference
 * mask: can be null if is_test=true
 *
 * Launch layout as used by this kernel: threads along x each own VecSize
 * consecutive columns; blockIdx.y selects the starting row. Both dimensions
 * are walked with a stride of blockDim * gridDim, so the kernel does not
 * require the grid to cover the tensor exactly. No tail handling is done
 * inside a vector, so cols is expected to be a multiple of VecSize.
 *
 * seed / increment are forwarded to curand_init as seed and offset; the
 * per-thread flat index is used as the subsequence.
 */
template <typename T, typename MaskType, int VecSize>
__global__ void FusedResidualDropoutBias(
    const size_t rows, const size_t cols, uint64_t seed,
    const float dropout_prob, const bool is_upscale_in_train,
    const T *__restrict__ src, const T *__restrict__ residual,
    const T *__restrict__ bias, MaskType *mask, T *dst, uint64_t increment,
    const bool is_test) {
  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
  int row_id = blockIdx.y;

  // Keep the flat index in 64 bits: it is the cuRAND subsequence id and must
  // not be narrowed to int, otherwise large tensors produce wrapped ids.
  const uint64_t idx = static_cast<uint64_t>(row_id) * cols + col_id;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);

  const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
  // The functor is required by the helper's signature but is never invoked
  // because the Activation template flag is false below.
  phi::funcs::ReluFunctor<T> relu;

  // loop-invariant strides
  const int row_stride = blockDim.y * gridDim.y;
  const int col_stride = blockDim.x * gridDim.x * VecSize;

  for (int r = row_id; r < rows; r += row_stride) {
    for (int i = col_id * VecSize; i < cols; i += col_stride) {
      FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, false, false,
                                        phi::funcs::ReluFunctor<T>>(
          r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
          mask, is_test, nullptr, nullptr, relu);
    }
  }
}

/**
 * @brief dst = residual + dropout(src + bias);
 *
 * Host-side launcher for FusedResidualDropoutBias on ctx.stream().
 *
 * @param rows, cols          shape of src / residual / dst / mask_data
 * @param increment, seed     forwarded to the kernel for RNG initialisation
 * @param dropout_prob        dropout probability; values within 1e-5 of 1.0
 *                            take a fast path that launches no kernel
 * @param is_test             inference mode: mask_data is not written
 * @param is_upscale_in_train forwarded to the kernel (used by GetFactor)
 * @param src, residual, bias device inputs; residual and bias may be null
 * @param mask_data           device output mask (may be null if is_test)
 * @param dst                 device output; may alias residual
 * @param ctx                 device context providing place and stream
 */
template <typename T, typename MaskType>
void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
                               const int increment, uint64_t seed,
                               const float dropout_prob, const bool is_test,
                               bool is_upscale_in_train, const T *src,
                               const T *residual, const T *bias,
                               MaskType *mask_data, T *dst,
                               const platform::CUDADeviceContext &ctx) {
  // Widen before multiplying: rows * cols in uint32_t wraps above 2^32.
  const size_t numel = static_cast<size_t>(rows) * cols;

  // dropout_prob == 1.0f: everything is dropped, so dst is just the residual
  // (or zero when there is no residual) and the mask is all zeros.
  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
    if (residual == nullptr) {
      // No residual to copy from; the result is identically zero.
      SetZero<T>(ctx, dst, numel);
    } else if (residual != dst) {
      auto cuda_place = ctx.GetPlace();
      memory::Copy(cuda_place, dst, cuda_place, residual, numel * sizeof(T),
                   ctx.stream());
    }
    // The mask must be cleared even for the in-place case (residual == dst),
    // otherwise the backward pass would read an uninitialised mask.
    if (!is_test) {
      SetZero<MaskType>(ctx, mask_data, numel);
    }
    return;
  }

  // Use the widest vector width only when every row is a whole number of
  // vectors; otherwise fall back to scalar access.
  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
  auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
  if (real_vec_size == VecSize) {
    FusedResidualDropoutBias<T, MaskType, VecSize><<<
        config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
        rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
        bias, mask_data, dst, increment, is_test);
  } else {
    FusedResidualDropoutBias<T, MaskType, 1><<<
        config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
        rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
        bias, mask_data, dst, increment, is_test);
  }
}

/**
 * @brief calculate the grad of no bias
 *
 * dx[i] = dout[i] * mask[i] * factor over a flat array of `size` elements.
 * Each thread handles VecSize consecutive elements per iteration and advances
 * by blockDim.x * gridDim.x * VecSize, so any 1-D launch configuration is
 * valid. There is no tail handling inside a vector: `size` must be a multiple
 * of VecSize.
 */
template <typename T, typename MaskType, int VecSize>
__global__ void FusedResidualDropoutGrad(const T *dout, const MaskType *mask,
                                         const T factor, const int64_t size,
                                         T *dx) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  // All index arithmetic is done in 64 bits: `size` is int64_t and a 32-bit
  // loop counter would overflow for arrays with more than 2^31 elements.
  const int64_t thread_id =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;

  for (int64_t i = thread_id * VecSize; i < size; i += stride) {
    LoadT dout_vec;
    MaskLoadT mask_vec;
    phi::Load<T, VecSize>(&dout[i], &dout_vec);
    phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);

    StoreT dx_vec;
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
    }
    phi::Store<T, VecSize>(dx_vec, &dx[i]);
  }
}

/**
 * @brief calculate dx and the per-column bias gradient in one pass.
 *
 * blocks(128 * 8)
 * 1. calculate the dx and reduce total rows to 128 rows
 * 2. save 128*8 temporary sum in 8*128 shared memory
 * 3. reduce the sum of 128 rows data by 8*VecSize warps
 * (steps 2 and 3 are performed by CalculateDBias, defined elsewhere.)
 *
 * Thread layout used here: threadIdx.x / blockIdx.x select a group of VecSize
 * columns, threadIdx.y selects the starting row and rows are walked with a
 * stride of blockDim.y. BlockSizeX / BlockSizeY are forwarded to
 * CalculateDBias and are expected to match the launch's block dimensions.
 *
 * NOTE(review): tmp_sum accumulates the raw dout, not the masked and scaled
 * dx. Confirm this is the intended dbias for
 * dst = residual + dropout(src + bias).
 */
template <typename T, typename MaskType, int BlockSizeX, int BlockSizeY,
          int VecSize>
__global__ void FusedResidualDropoutBiasGrad(const T *dout,
                                             const MaskType *mask,
                                             const T factor, const int64_t rows,
                                             const int64_t cols, T *dx,
                                             T *dbias) {
  int64_t col_id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  T tmp_sum[VecSize] = {static_cast<T>(0)};
  // calculate the dx and temporary sum
  if (col_id * VecSize < cols) {
    for (int64_t row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
      // 64-bit flat index: row_id * cols does not fit in int for large inputs
      const int64_t index = row_id * cols + col_id * VecSize;
      LoadT out_vec;
      MaskLoadT mask_vec;
      StoreT dx_vec;
      phi::Load<T, VecSize>(&dout[index], &out_vec);
      phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);

#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        dx_vec[i] = out_vec[i] * static_cast<T>(mask_vec[i]) * factor;
        tmp_sum[i] += out_vec[i];
      }

      phi::Store<T, VecSize>(dx_vec, &dx[index]);
    }
  }

  // Called by every thread of the block, including those whose columns are
  // out of range (they contribute zeros), so it stays outside the guard.
  CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}

/**
 * @brief to launch kernel FusedResidualDropoutBiasGrad (when dbias is
 * requested) or FusedResidualDropoutGrad (when dbias is null).
 *
 * @param dout                 upstream gradient, shape (rows, cols)
 * @param mask                 dropout mask saved by the forward pass
 * @param dropout_prob         dropout probability used in the forward pass
 * @param is_upscale_in_train  when false the gradient scale factor is 1
 * @param rows, cols           shape of dout / mask / dx
 * @param dx                   output gradient w.r.t. src
 * @param dbias                output gradient w.r.t. bias, or null to skip it
 * @param ctx                  device context providing the stream
 */
template <typename T, typename MaskType>
void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask,
                                   const float dropout_prob,
                                   const bool is_upscale_in_train,
                                   const uint32_t rows, const uint32_t cols,
                                   T *dx, T *dbias,
                                   const platform::CUDADeviceContext &ctx) {
  const T zero = static_cast<T>(0.0f);
  // Use the same "dropout_prob is 1" tolerance as the forward launcher. An
  // exact compare lets 1 / (1 - p) overflow to inf in low-precision T for p
  // just below 1, and 0 * inf would then turn the gradient into NaN.
  auto factor = std::abs(dropout_prob - 1.0f) < 1e-5
                    ? zero
                    : static_cast<T>(1.0f / (1.0f - dropout_prob));
  if (!is_upscale_in_train) {
    factor = static_cast<T>(1.0f);
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  if (dbias != nullptr) {
    // Block shape; must match the BlockSizeX / BlockSizeY template arguments.
    constexpr int kBlockX = 8;
    constexpr int kBlockY = 128;
    // Vectorise only when every row is a whole number of vectors.
    const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
    auto blocks = std::max(static_cast<uint32_t>(1),
                           (cols / real_vec_size + kBlockX - 1) / kBlockX);
    dim3 block_dim(kBlockX, kBlockY, 1);
    dim3 grid_dim(blocks, 1, 1);
    if (real_vec_size == VecSize) {
      FusedResidualDropoutBiasGrad<T, MaskType, kBlockX, kBlockY, VecSize><<<
          grid_dim, block_dim, 0, ctx.stream()>>>(dout, mask, factor, rows,
                                                  cols, dx, dbias);
    } else {
      FusedResidualDropoutBiasGrad<T, MaskType, kBlockX, kBlockY, 1><<<
          grid_dim, block_dim, 0, ctx.stream()>>>(dout, mask, factor, rows,
                                                  cols, dx, dbias);
    }
  } else {
    // Widen before multiplying: rows * cols in uint32_t wraps above 2^32.
    const uint64_t n = static_cast<uint64_t>(rows) * cols;
    // The flat kernel only needs the total element count to be a multiple of
    // the vector width; size the launch with the same width that selects the
    // kernel below.
    const int flat_vec_size = n % VecSize == 0 ? VecSize : 1;
    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx, n / flat_vec_size);
    if (flat_vec_size == VecSize) {
      FusedResidualDropoutGrad<T, MaskType, VecSize><<<
          config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
          dout, mask, factor, n, dx);
    } else {
      FusedResidualDropoutGrad<T, MaskType, 1><<<
          config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
          dout, mask, factor, n, dx);
    }
  }
}

}  // namespace operators
}  // namespace paddle