fused_residual_dropout_bias.h 15.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/fused_dropout_common.h"

namespace paddle {
namespace operators {

/**
 * @brief The fused function called by every thread.
 *
 * Handles VecSize consecutive elements of row `row_id` starting at column
 * `col_id` and computes
 *   dst = residual + dropout(act(dequant(src) + bias))
 * where dropout multiplies by the 0/1 mask and by `factor`. Callers in this
 * file pass a col_id that is a multiple of VecSize.
 * VecSize can be 1, 2, 4 or 8.
 *
 * - residual / bias may be null; they are then treated as zeros.
 * - mask is only drawn (via RandVec) and written when is_test is false; in
 *   test mode every element is kept.
 * - mean_val / var_val receive the running sum and sum of squares of the
 *   result only when ComputeLayerNorm is true.
 * - When InType is int32_t, src is dequantized as
 *     src * quant_last_in_scale / dequant_out_scale_data[offset + col]
 *   before the bias is added. The per-column scales are loaded only in that
 *   case (the pointer defaults to nullptr); if it is null a scale of 1 is
 *   used.
 * - When OutType is int8_t, the result is re-quantized with quant_helper
 *   using quant_next_in_scale, quant_round_type and the min/max bounds.
 */
template <typename T,
          typename MaskType,
          int VecSize,
          bool ComputeLayerNorm,
          bool Activation,
          typename Functor,
          typename InType = T,
          typename OutType = T>
__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
    const int row_id,
    const int col_id,
    const int cols,
    curandStatePhilox4_32_10_t *state,
    const float dropout_prob,
    const T factor,
    const InType *__restrict__ src,
    const T *__restrict__ residual,
    const T *__restrict__ bias,
    OutType *dst,
    MaskType *mask,
    const bool is_test,
    typename details::MPTypeTrait<T>::Type *mean_val,
    typename details::MPTypeTrait<T>::Type *var_val,
    Functor act_func,
    const float quant_last_in_scale = 1.0,
    const float *dequant_out_scale_data = nullptr,
    const int quant_out_scale_offset = 0,
    const float quant_next_in_scale = 1.0,
    const int quant_round_type = 1,
    const float quant_max_bound = 127.0,
    const float quant_min_bound = -127.0) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using LoadInType = phi::AlignedVector<InType, VecSize>;
  using LoadFloat = phi::AlignedVector<float, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using StoreOutType = phi::AlignedVector<OutType, VecSize>;
  using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
  using U = typename details::MPTypeTrait<T>::Type;

  constexpr bool kDequantInput = std::is_same<InType, int32_t>::value;
  constexpr bool kQuantOutput = std::is_same<OutType, int8_t>::value;

  // Element offset of the first value handled by this thread. Computed in
  // 64 bits so that row_id * cols cannot overflow a 32-bit int.
  const int64_t offset = static_cast<int64_t>(row_id) * cols + col_id;

  LoadInType src_vec;
  LoadT residual_vec;
  LoadT bias_vec;
  LoadFloat quant_out_scale_vec;
#pragma unroll
  for (int ii = 0; ii < VecSize; ii++) {
    bias_vec[ii] = static_cast<T>(0);
    residual_vec[ii] = static_cast<T>(0);
    quant_out_scale_vec[ii] = 1.0f;
  }

  // vectorize load data from global
  phi::Load<InType, VecSize>(&src[offset], &src_vec);
  // The dequant scales are only consumed for int32 input. Loading them
  // unconditionally dereferences the default nullptr in the non-quant path.
  if (kDequantInput && dequant_out_scale_data != nullptr) {
    phi::Load<float, VecSize>(
        &dequant_out_scale_data[quant_out_scale_offset + col_id],
        &quant_out_scale_vec);
  }
  if (residual) {
    phi::Load<T, VecSize>(&residual[offset], &residual_vec);
  }
  if (bias) {
    phi::Load<T, VecSize>(&bias[col_id], &bias_vec);
  }

  // Keep an element when its uniform draw is >= dropout_prob; keep all in test.
  MaskStoreT mask_vec;
  if (!is_test) {
    float rand[VecSize];
    RandVec<VecSize>(state, rand);
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
    }
  } else {
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      mask_vec[ii] = static_cast<MaskType>(1);
    }
  }

  StoreT dest_vec;
  StoreOutType dest_vec_out_type;
#pragma unroll
  for (int ii = 0; ii < VecSize; ii++) {
    T tmp;
    if (kDequantInput) {
      T tmp0 = static_cast<T>(static_cast<float>(src_vec[ii]) *
                              quant_last_in_scale / quant_out_scale_vec[ii]);
      tmp = tmp0 + bias_vec[ii];
    } else {
      tmp = static_cast<T>(src_vec[ii]) + bias_vec[ii];
    }
    if (Activation) {
      tmp = act_func(tmp);
    }
    dest_vec[ii] =
        tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
    if (ComputeLayerNorm) {
      // Accumulate layer-norm statistics in the higher-precision type U.
      U dest_u = static_cast<U>(dest_vec[ii]);
      *mean_val += dest_u;
      *var_val += (dest_u * dest_u);
    }
    if (kQuantOutput) {
      dest_vec_out_type[ii] = quant_helper(dest_vec[ii],
                                           quant_next_in_scale,
                                           quant_round_type,
                                           quant_max_bound,
                                           quant_min_bound);
    }
  }

  // store result to global
  if (kQuantOutput) {
    phi::Store<OutType, VecSize>(dest_vec_out_type, &dst[offset]);
  } else {
    phi::Store<T, VecSize>(dest_vec, reinterpret_cast<T *>(&dst[offset]));
  }
  if (!is_test) {
    phi::Store<MaskType, VecSize>(mask_vec, &mask[offset]);
  }
}

/**
 * @brief dst = residual + dropout(src + bias);
 * the src, residual, mask and dst shape is (rows, cols)
 * the bias shape is (1, cols)
 * is_test: only used in inference
 * mask: can be null if is_test=true
 *
 * Thread mapping: blockIdx.y picks the first row and the x-dimension thread
 * index picks the first VecSize-wide column chunk. Each thread then steps by
 * blockDim.y * gridDim.y rows and blockDim.x * gridDim.x * VecSize columns.
 * No activation and no layer-norm statistics are computed by this kernel.
 */
template <typename T,
          typename MaskType,
          int VecSize,
          typename InType = T,
          typename OutType = T>
__global__ void FusedResidualDropoutBias(
    const size_t rows,
    const size_t cols,
    uint64_t seed,
    const float dropout_prob,
    const bool is_upscale_in_train,
    const InType *__restrict__ src,
    const T *__restrict__ residual,
    const T *__restrict__ bias,
    MaskType *mask,
    OutType *dst,
    uint64_t increment,
    const bool is_test,
    const float quant_last_in_scale = 1.0,
    const float *dequant_out_scale_data = nullptr,
    const int quant_out_scale_offset = 0,
    const float quant_next_in_scale = 1.0) {
  const int thread_col = blockDim.x * blockIdx.x + threadIdx.x;
  const int thread_row = blockIdx.y;

  // One Philox state per thread; the subsequence is derived from the
  // thread's position (row * cols + col).
  curandStatePhilox4_32_10_t rng;
  {
    const int subsequence = thread_row * cols + thread_col;
    curand_init(seed, subsequence, increment, &rng);
  }

  const T scale = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
  // Activation is disabled in the call below; this functor only fills the
  // template/argument slot.
  phi::funcs::ReluFunctor<T> act_placeholder;

  const int row_step = blockDim.y * gridDim.y;
  const int col_step = blockDim.x * gridDim.x * VecSize;
  for (int row = thread_row; row < rows; row += row_step) {
    for (int col = thread_col * VecSize; col < cols; col += col_step) {
      FusedResidualDropoutBiasOneThread<T,
                                        MaskType,
                                        VecSize,
                                        /*ComputeLayerNorm=*/false,
                                        /*Activation=*/false,
                                        phi::funcs::ReluFunctor<T>,
                                        InType,
                                        OutType>(row,
                                                 col,
                                                 cols,
                                                 &rng,
                                                 dropout_prob,
                                                 scale,
                                                 src,
                                                 residual,
                                                 bias,
                                                 dst,
                                                 mask,
                                                 is_test,
                                                 /*mean_val=*/nullptr,
                                                 /*var_val=*/nullptr,
                                                 act_placeholder,
                                                 quant_last_in_scale,
                                                 dequant_out_scale_data,
                                                 quant_out_scale_offset,
                                                 quant_next_in_scale);
    }
  }
}

/**
 * @brief dst = residual + dropout(src + bias);
 *
 * Host-side launcher; all work is enqueued on ctx.stream().
 * - Empty input (rows == 0 or cols == 0) is a no-op.
 * - When dropout_prob is (within 1e-5 of) 1, every element is dropped. dst
 *   becomes a copy of residual (or zeros if residual is null). When training,
 *   mask_data is zeroed, including the in-place case residual == dst.
 * - Otherwise FusedResidualDropoutBias is launched. It is vectorized
 *   (MAX_CACHE_BYTES / sizeof(T) wide) when cols is a multiple of that width,
 *   and scalar otherwise.
 */
template <typename T,
          typename MaskType,
          typename InType = T,
          typename OutType = T>
void LaunchResidualDropoutBias(const uint32_t rows,
                               const uint32_t cols,
                               const int increment,
                               uint64_t seed,
                               const float dropout_prob,
                               const bool is_test,
                               bool is_upscale_in_train,
                               const InType *src,
                               const T *residual,
                               const T *bias,
                               MaskType *mask_data,
                               OutType *dst,
                               const phi::GPUContext &ctx,
                               const float quant_last_in_scale = 1.0,
                               const float *dequant_out_scale_data = nullptr,
                               const int quant_out_scale_offset = 0,
                               const float quant_next_in_scale = 1.0) {
  // Nothing to do; also avoids launching a kernel with an empty grid.
  if (rows == 0 || cols == 0) return;
  // Widen before multiplying so rows * cols cannot wrap at 2^32.
  const size_t numel = static_cast<size_t>(rows) * cols;

  // dropout_prob == 1.0f
  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
    // NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0
    // The casts keep this path compilable for OutType != T instantiations.
    T *dst_t = reinterpret_cast<T *>(dst);
    // When residual aliases dst the output already holds the right values.
    if (residual != dst_t) {
      if (residual) {
        memory::Copy(ctx.GetPlace(),
                     dst_t,
                     ctx.GetPlace(),
                     residual,
                     numel * sizeof(T),
                     ctx.stream());
      } else {
        SetZero<T>(ctx, dst_t, numel);
      }
    }
    // The mask must be cleared even when the copy above was skipped.
    if (!is_test) {
      SetZero<MaskType>(ctx, mask_data, numel);
    }
    return;
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
  auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
  if (cols % VecSize == 0) {
    FusedResidualDropoutBias<T, MaskType, VecSize, InType, OutType>
        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
            rows,
            cols,
            seed,
            dropout_prob,
            is_upscale_in_train,
            src,
            residual,
            bias,
            mask_data,
            dst,
            increment,
            is_test,
            quant_last_in_scale,
            dequant_out_scale_data,
            quant_out_scale_offset,
            quant_next_in_scale);
  } else {
    FusedResidualDropoutBias<T, MaskType, 1, InType, OutType>
        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
            rows,
            cols,
            seed,
            dropout_prob,
            is_upscale_in_train,
            src,
            residual,
            bias,
            mask_data,
            dst,
            increment,
            is_test,
            quant_last_in_scale,
            dequant_out_scale_data,
            quant_out_scale_offset,
            quant_next_in_scale);
  }
}

/*
 * @brief calculate the grad when there is no bias: dx = dout * mask * factor
 *
 * Each thread starts at element (global thread index) * VecSize and advances
 * by blockDim.x * gridDim.x * VecSize elements. With VecSize > 1, `size` is
 * expected to be a multiple of VecSize; the launcher in this file only picks a
 * vector width when that holds.
 */
template <typename T, typename MaskType, int VecSize>
__global__ void FusedResidualDropoutGrad(const T *dout,
                                         const MaskType *mask,
                                         const T factor,
                                         const int64_t size,
                                         T *dx) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  // Index and stride are 64-bit because `size` is int64_t and may exceed
  // INT_MAX; a 32-bit loop counter would overflow.
  const int64_t idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;

  for (int64_t i = idx * VecSize; i < size; i += stride) {
    LoadT dout_vec;
    MaskLoadT mask_vec;
    phi::Load<T, VecSize>(&dout[i], &dout_vec);
    phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);

    StoreT dx_vec;
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
    }
    phi::Store<T, VecSize>(dx_vec, &dx[i]);
  }
}

/**
 * Backward of dst = residual + dropout(src + bias) with a bias gradient.
 *
 * Launch layout: blocks of (BlockSizeX, BlockSizeY) threads, e.g. 8 x 128.
 * 1. Each thread owns one VecSize-wide column chunk (by its x index), walks
 *    rows threadIdx.y, threadIdx.y + blockDim.y, ..., writes
 *    dx = dout * mask * factor and accumulates dx into per-thread partial
 *    sums.
 * 2. CalculateDBias reduces the BlockSizeY partial sums of each column into
 *    dbias. Per the original design notes it uses a shared-memory reduction.
 *
 * Because bias sits inside the dropout, d(dst)/d(bias) = mask * factor. The
 * bias gradient is therefore the column sum of dx, not of dout.
 */
template <typename T,
          typename MaskType,
          int BlockSizeX,
          int BlockSizeY,
          int VecSize>
__global__ void FusedResidualDropoutBiasGrad(const T *dout,
                                             const MaskType *mask,
                                             const T factor,
                                             const int64_t rows,
                                             const int64_t cols,
                                             T *dx,
                                             T *dbias) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  const int64_t col_id =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  T tmp_sum[VecSize];
#pragma unroll
  for (int i = 0; i < VecSize; i++) {
    tmp_sum[i] = static_cast<T>(0);
  }

  // calculate the dx and temporary sum
  if (col_id * VecSize < cols) {
    // 64-bit row/index: rows and cols are int64_t and row_id * cols may
    // exceed INT_MAX.
    for (int64_t row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
      const int64_t index = row_id * cols + col_id * VecSize;
      LoadT out_vec;
      MaskLoadT mask_vec;
      StoreT dx_vec;
      phi::Load<T, VecSize>(&dout[index], &out_vec);
      phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);

#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        dx_vec[i] = out_vec[i] * static_cast<T>(mask_vec[i]) * factor;
        tmp_sum[i] += dx_vec[i];
      }

      phi::Store<T, VecSize>(dx_vec, &dx[index]);
    }
  }

  // Kept outside the bounds check: presumably a block-wide (shared-memory)
  // reduction that every thread must reach.
  CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}

/**
 * @brief Launch the backward kernels of dst = residual + dropout(src + bias).
 *
 * dx = dout * mask * factor. factor is 1 / (1 - dropout_prob) when
 * is_upscale_in_train (0 when dropout_prob == 1), and 1 otherwise.
 * If dbias is non-null, FusedResidualDropoutBiasGrad also reduces the bias
 * gradient over rows. Otherwise FusedResidualDropoutGrad computes dx alone.
 * Kernels are enqueued on ctx.stream(); nothing is synchronized here.
 */
template <typename T, typename MaskType>
void LaunchResidualDropoutBiasGrad(const T *dout,
                                   const MaskType *mask,
                                   const float dropout_prob,
                                   const bool is_upscale_in_train,
                                   const uint32_t rows,
                                   const uint32_t cols,
                                   T *dx,
                                   T *dbias,
                                   const phi::GPUContext &ctx) {
  const T zero = static_cast<T>(0.0f);
  auto factor = dropout_prob == static_cast<float>(1.0f)
                    ? zero
                    : static_cast<T>(1.0f / (1.0f - dropout_prob));
  if (!is_upscale_in_train) {
    factor = static_cast<T>(1.0f);
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  if (dbias != nullptr) {
    // One thread per column chunk (VecSize wide when cols allows it).
    // Blocks are 8 x 128 threads, with one block per 8 column chunks.
    const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
    const auto threads = 8;
    auto blocks = std::max(static_cast<uint32_t>(1),
                           (cols / real_vec_size + threads - 1) / threads);
    dim3 block_dim(threads, 128, 1);
    dim3 grid_dim(blocks, 1, 1);
    if (cols % VecSize == 0) {
      FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, VecSize>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(
              dout, mask, factor, rows, cols, dx, dbias);
    } else {
      FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, 1>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(
              dout, mask, factor, rows, cols, dx, dbias);
    }
  } else {
    // Widen before multiplying: rows * cols in uint32_t wraps past 2^32.
    const uint64_t n = static_cast<uint64_t>(rows) * cols;
    // Nothing to compute; also avoids launching an empty grid.
    if (n == 0) return;
    // Size the launch with the same vector width the kernel is dispatched
    // with (both decided on n).
    const bool vectorized = n % VecSize == 0;
    const int real_vec_size = vectorized ? VecSize : 1;
    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
    if (vectorized) {
      FusedResidualDropoutGrad<T, MaskType, VecSize>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              dout, mask, factor, n, dx);
    } else {
      FusedResidualDropoutGrad<T, MaskType, 1>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              dout, mask, factor, n, dx);
    }
  }
}

}  // namespace operators
}  // namespace paddle