/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/fused_dropout_common.h"

namespace paddle {
namespace operators {

/**
 * @brief The fused function called by every thread
 * VecSize can be 1, 2, 4 or 8
 */
template <typename T,
          typename MaskType,
          int VecSize,
          bool ComputeLayerNorm,
          bool Activation,
          typename Functor>
__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
    const int row_id,
    const int col_id,
    const int cols,
    curandStatePhilox4_32_10_t *state,
    const float dropout_prob,
    const T factor,
    const T *__restrict__ src,
    const T *__restrict__ residual,
    const T *__restrict__ bias,
    T *dst,
    MaskType *mask,
    const bool is_test,
    typename details::MPTypeTrait<T>::Type *mean_val,
    typename details::MPTypeTrait<T>::Type *var_val,
    Functor act_func) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
  using U = typename details::MPTypeTrait<T>::Type;

  LoadT src_vec;
  LoadT residual_vec;
  LoadT bias_vec;
#pragma unroll
  for (int ii = 0; ii < VecSize; ii++) {
    bias_vec[ii] = static_cast<T>(0);
    residual_vec[ii] = static_cast<T>(0);
  }
  // vectorize load data from global
  phi::Load<T, VecSize>(&src[row_id * cols + col_id], &src_vec);
  if (residual) {
    phi::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec);
  }

  if (bias) {
    phi::Load<T, VecSize>(&bias[col_id], &bias_vec);
  }

  MaskStoreT mask_vec;
  if (!is_test) {
    float rand[VecSize];
    RandVec<VecSize>(state, rand);
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
    }
  } else {
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      mask_vec[ii] = static_cast<MaskType>(1);
    }
  }

  StoreT dest_vec;

#pragma unroll
  for (int ii = 0; ii < VecSize; ii++) {
    T tmp = src_vec[ii] + bias_vec[ii];
    if (Activation) {
      tmp = act_func(tmp);
    }
    dest_vec[ii] =
        tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
    if (ComputeLayerNorm) {
      U tmp = static_cast<U>(dest_vec[ii]);
      *mean_val += tmp;
      *var_val += (tmp * tmp);
    }
  }

  // store result to global
  phi::Store<T, VecSize>(dest_vec, &dst[row_id * cols + col_id]);
  if (!is_test) {
    phi::Store<MaskType, VecSize>(mask_vec, &mask[row_id * cols + col_id]);
  }
}
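
// Illustrative per-element sketch (not part of the original kernel): for one
// element at index idx = row_id * cols + col_id + ii, the fused function
// above behaves roughly like the scalar code below, where `keep` stands in
// for the Philox-based dropout sample and `factor` is the scale obtained from
// GetFactor:
//
//   T tmp = src[idx] + bias[col];                 // add bias (0 if no bias)
//   if (Activation) tmp = act_func(tmp);          // optional activation
//   MaskType keep = is_test ? 1 : (rand >= dropout_prob);
//   dst[idx] = tmp * static_cast<T>(keep) * factor + residual[idx];
//   if (ComputeLayerNorm) {                       // accumulate LN statistics
//     *mean_val += static_cast<U>(dst[idx]);
//     *var_val += static_cast<U>(dst[idx]) * static_cast<U>(dst[idx]);
//   }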

/**
 * @brief dst = residual + dropout(src + bias);
 * src, residual, mask and dst have shape (rows, cols),
 * and bias has shape (1, cols).
 * is_test: true only for inference; when set, no dropout mask is written
 * mask: may be null when is_test is true
 */
template <typename T, typename MaskType, int VecSize>
__global__ void FusedResidualDropoutBias(const size_t rows,
                                         const size_t cols,
                                         uint64_t seed,
                                         const float dropout_prob,
                                         const bool is_upscale_in_train,
                                         const T *__restrict__ src,
                                         const T *__restrict__ residual,
                                         const T *__restrict__ bias,
                                         MaskType *mask,
                                         T *dst,
                                         uint64_t increment,
                                         const bool is_test) {
  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
  int row_id = blockIdx.y;
  int idx = row_id * cols + col_id;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
  const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
  phi::funcs::ReluFunctor<T> relu;
  for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
    for (int i = col_id * VecSize; i < cols;
         i += blockDim.x * gridDim.x * VecSize) {
      FusedResidualDropoutBiasOneThread<T,
                                        MaskType,
                                        VecSize,
                                        false,
                                        false,
                                        phi::funcs::ReluFunctor<T>>(
          r,
          i,
          cols,
          &state,
          dropout_prob,
          factor,
          src,
          residual,
          bias,
          dst,
          mask,
          is_test,
          nullptr,
          nullptr,
          relu);
    }
  }
}
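
// Thread-to-data mapping sketch (illustrative): each thread handles VecSize
// consecutive columns of one row and grid-strides over both dimensions,
// roughly as in
//
//   for (r = blockIdx.y; r < rows; r += blockDim.y * gridDim.y)
//     for (c = (blockDim.x * blockIdx.x + threadIdx.x) * VecSize; c < cols;
//          c += blockDim.x * gridDim.x * VecSize)
//       // process elements [r * cols + c, r * cols + c + VecSize)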

/**
 * @brief dst = residual + dropout(src + bias);
 */
template <typename T, typename MaskType>
void LaunchResidualDropoutBias(const uint32_t rows,
                               const uint32_t cols,
                               const int increment,
                               uint64_t seed,
                               const float dropout_prob,
                               const bool is_test,
                               bool is_upscale_in_train,
                               const T *src,
                               const T *residual,
                               const T *bias,
                               MaskType *mask_data,
                               T *dst,
                               const platform::CUDADeviceContext &ctx) {
  // dropout_prob == 1.0f: every element is dropped, so dst is just the
  // residual (or zero when there is no residual) and the mask is all zeros
  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
    if (residual == dst) return;
    if (residual) {
      memory::Copy(ctx.GetPlace(),
                   dst,
                   ctx.GetPlace(),
                   residual,
                   rows * cols * sizeof(T),
                   ctx.stream());
    } else {
      SetZero<T>(ctx, dst, rows * cols);
    }
    if (!is_test) {
      SetZero<MaskType>(ctx, mask_data, rows * cols);
    }
    return;
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
  auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
  if (cols % VecSize == 0) {
    FusedResidualDropoutBias<T, uint8_t, VecSize>
        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
            rows,
            cols,
            seed,
            dropout_prob,
            is_upscale_in_train,
            src,
            residual,
            bias,
            mask_data,
            dst,
            increment,
            is_test);
  } else {
    FusedResidualDropoutBias<T, uint8_t, 1>
        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
            rows,
            cols,
            seed,
            dropout_prob,
            is_upscale_in_train,
            src,
            residual,
            bias,
            mask_data,
            dst,
            increment,
            is_test);
  }
}
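
// Usage sketch (hypothetical; the device pointers, seed/increment pair and
// names such as d_src, d_residual, d_bias, d_mask, d_out and dev_ctx are
// illustrative and must be prepared by the caller):
//
//   LaunchResidualDropoutBias<float, uint8_t>(
//       rows, cols, increment, seed,
//       /*dropout_prob=*/0.1f, /*is_test=*/false,
//       /*is_upscale_in_train=*/true,
//       d_src, d_residual, d_bias, d_mask, d_out, dev_ctx);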

/*
 * @brief calculate the gradient dx for the case without bias
 */
template <typename T, typename MaskType, int VecSize>
__global__ void FusedResidualDropoutGrad(const T *dout,
                                         const MaskType *mask,
                                         const T factor,
                                         const int64_t size,
                                         T *dx) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;

  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
  for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
    LoadT dout_vec;
    MaskLoadT mask_vec;
    phi::Load<T, VecSize>(&dout[i], &dout_vec);
    phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);

    StoreT dx_vec;
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
    }
    phi::Store<T, VecSize>(dx_vec, &dx[i]);
  }
}
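
// Per element this kernel computes dx[i] = dout[i] * mask[i] * factor, i.e.
// the dropout gradient with the same keep-mask and scale as the forward pass;
// the residual branch is a plain addition, so its gradient is dout itself and
// needs no extra work here.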

/**
 * blocks(128 * 8)
 * 1. calculate dx and reduce the total rows to 128 rows
 * 2. save the 128*8 temporary sums in 8*128 shared memory
 * 3. reduce the sums of the 128 rows by 8*VecSize warps
 */
template <typename T,
          typename MaskType,
          int BlockSizeX,
          int BlockSizeY,
          int VecSize>
__global__ void FusedResidualDropoutBiasGrad(const T *dout,
                                             const MaskType *mask,
                                             const T factor,
                                             const int64_t rows,
                                             const int64_t cols,
                                             T *dx,
                                             T *dbias) {
  int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;

  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  T tmp_sum[VecSize] = {static_cast<T>(0)};
  // calculate the dx and temporary sum
  if (col_id * VecSize < cols) {
    for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
      int index = row_id * cols + col_id * VecSize;
      LoadT out_vec;
      MaskLoadT mask_vec;
      StoreT dx_vec;
      phi::Load<T, VecSize>(&dout[index], &out_vec);
      phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);

#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        dx_vec[i] = out_vec[i] * static_cast<T>(mask_vec[i]) * factor;
        tmp_sum[i] += out_vec[i];
      }

      phi::Store<T, VecSize>(dx_vec, &dx[index]);
    }
  }

  CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}
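
// Reduction sketch: each thread keeps a VecSize-wide running sum of dout for
// its columns in tmp_sum (registers); CalculateDBias then combines these
// partial sums across the BlockSizeX x BlockSizeY block (using the
// shared-memory / warp reduction described in the comment above this kernel)
// to produce one dbias value per column.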

/**
 * @brief to launch the kernel FusedResidualDropoutBiasGrad
 */
template <typename T, typename MaskType>
void LaunchResidualDropoutBiasGrad(const T *dout,
                                   const MaskType *mask,
                                   const float dropout_prob,
                                   const bool is_upscale_in_train,
                                   const uint32_t rows,
                                   const uint32_t cols,
                                   T *dx,
                                   T *dbias,
                                   const platform::CUDADeviceContext &ctx) {
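  // In upscale_in_train mode the forward pass scaled the kept activations by
  // 1 / (1 - dropout_prob), so the backward pass applies the same factor
  // (0 when dropout_prob == 1); otherwise the mask alone carries the scaling
  // and the factor is 1.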
  const T zero = static_cast<T>(0.0f);
  auto factor = dropout_prob == static_cast<float>(1.0f)
                    ? zero
                    : static_cast<T>(1.0f / (1.0f - dropout_prob));
  if (!is_upscale_in_train) {
    factor = static_cast<T>(1.0f);
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
  if (dbias != nullptr) {
    const auto threads = 8;
    auto blocks = std::max(static_cast<uint32_t>(1),
                           (cols / real_vec_size + threads - 1) / threads);
    dim3 block_dim(threads, 128, 1);
    dim3 grid_dim(blocks, 1, 1);
    if (cols % VecSize == 0) {
      FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, VecSize>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(
              dout, mask, factor, rows, cols, dx, dbias);
    } else {
      FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, 1>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(
              dout, mask, factor, rows, cols, dx, dbias);
    }
  } else {
    const uint64_t n = rows * cols;
    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
    if (n % VecSize == 0) {
      FusedResidualDropoutGrad<T, MaskType, VecSize>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              dout, mask, factor, n, dx);
    } else {
      FusedResidualDropoutGrad<T, MaskType, 1>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              dout, mask, factor, n, dx);
    }
  }
}
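
// Usage sketch (hypothetical; device pointers such as d_dout, d_mask, d_dx,
// d_dbias and the context dev_ctx are illustrative; pass dbias == nullptr
// when the forward pass had no bias):
//
//   LaunchResidualDropoutBiasGrad<float, uint8_t>(
//       d_dout, d_mask, /*dropout_prob=*/0.1f,
//       /*is_upscale_in_train=*/true, rows, cols, d_dx, d_dbias, dev_ctx);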

}  // namespace operators
}  // namespace paddle