fused_dropout_act_bias.h 16.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

20
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
21
#include "paddle/phi/kernels/gpu/gelu_funcs.h"
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36

namespace paddle {
namespace operators {

template <typename T>
struct GeluFunctor {
  // Exact (erf-based) GELU forward:
  //   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
  // Intermediates are held in LayerNormParamType<T>; the result is cast back
  // to T.
  inline __host__ __device__ T operator()(const T x) const {
    using U = LayerNormParamType<T>;
    const U xu = static_cast<U>(x);
    const U erf_val = erf(xu * static_cast<U>(M_SQRT1_2));
    const U half_x = xu * static_cast<U>(0.5);
    const U result = half_x * (static_cast<U>(1) + erf_val);
    return static_cast<T>(result);
  }
};

37 38 39 40 41 42 43
template <typename T>
struct FastGeluFunctor {
  // Device-only GELU forward that delegates to phi::GeluFwd<T, true>
  // (from paddle/phi/kernels/gpu/gelu_funcs.h).
  // NOTE(review): the `true` template argument presumably selects the
  // fast/approximate variant, as the struct name suggests — confirm against
  // gelu_funcs.h.
  inline __device__ T operator()(const T x) const {
    const T result = phi::GeluFwd<T, true>(x);
    return result;
  }
};

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
/**
 * @brief Gradient of the exact (erf-based) GELU.
 *
 * UseOut(x) returns d gelu / dx evaluated at x. Despite the method name, the
 * argument is the activation *input*, not its output:
 *   d/dx [x * Phi(x)] = Phi(x) + x * phi(x)
 *                     = 0.5 * (1 + erf(x / sqrt(2)))
 *                       + x * exp(-x^2 / 2) / sqrt(2 * pi)
 * Intermediates use LayerNormParamType<T>; the result is cast back to T.
 */
template <typename T>
struct GeluGradFunctor {
  inline __host__ __device__ T UseOut(const T x) const {
    using U = LayerNormParamType<T>;
    const U half = static_cast<U>(0.5);
    // 0.5 * (2 / sqrt(pi)) * (1 / sqrt(2)) == 1 / sqrt(2 * pi)
    const U inv_sqrt_2pi = static_cast<U>(0.5 * M_2_SQRTPI * M_SQRT1_2);
    auto xu = static_cast<U>(x);

    // Phi(x): standard normal CDF.
    auto cdf = half * (static_cast<U>(1) + erf(xu * static_cast<U>(M_SQRT1_2)));

    // x * phi(x): x times the standard normal PDF.
    auto exponent = -half * xu * xu;
    auto pdf_term = inv_sqrt_2pi * xu * exp(exponent);

    return static_cast<T>((cdf + pdf_term));
  }
};

/**
 * @brief dst = dropout(activation(src + bias));
 * the src, mask and dst shape is (rows, cols)
 * the bias shape is (1, cols)
 *
 * Thread mapping:
 *  - The starting column comes from the x dimension
 *    (blockDim.x * blockIdx.x + threadIdx.x).
 *  - The starting row is blockIdx.y.
 *  - Rows then advance by blockDim.y * gridDim.y.
 *  - Columns advance by blockDim.x * gridDim.x * VecSize, with VecSize
 *    columns handled per step.
 * NOTE(review): the starting row ignores threadIdx.y. With blockDim.y > 1,
 * threads of one block would duplicate each other's rows and other rows would
 * be skipped. This presumably relies on Get1DBlocksAnd2DGrids producing
 * 1-D blocks — confirm.
 *
 * Each thread seeds one Philox state (subsequence = start row * cols + start
 * column, offset = increment) and reuses it for all of its elements. The
 * per-element work is delegated to FusedResidualDropoutBiasOneThread with no
 * residual input.
 */
template <typename T,
          typename MaskType,
          int VecSize,
          typename Functor,
          typename InType = T,
          typename OutType = T>
__global__ void FusedDropoutActBias(
    Functor act,
    const uint64_t seed,
    const uint64_t rows,
    const uint64_t cols,
    const int increment,
    const float dropout_prob,
    const bool is_upscale_in_train,
    const bool is_test,
    const InType *__restrict__ src,
    const T *__restrict__ bias,
    OutType *dst,
    MaskType *mask,
    const float quant_last_in_scale = 1.0,
    const float *dequant_out_scale_data = nullptr,
    const int quant_out_scale_offset = 0,
    const float quant_next_in_scale = 1.0,
    const int quant_round_type = 1,
    const float quant_max_bound = 127.0,
    const float quant_min_bound = -127.0) {
  const int first_col = blockDim.x * blockIdx.x + threadIdx.x;
  const int first_row = blockIdx.y;
  // The uint64_t product is deliberately narrowed to int here; this fixes
  // the RNG subsequence layout, so do not widen it without considering
  // reproducibility.
  const int subsequence = first_row * cols + first_col;

  curandStatePhilox4_32_10_t rng;
  curand_init(seed, subsequence, increment, &rng);

  const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);

  // Strides are unsigned, matching the type of blockDim * gridDim.
  const unsigned int row_stride = blockDim.y * gridDim.y;
  const unsigned int col_stride = blockDim.x * gridDim.x * VecSize;

  for (int r = first_row; r < rows; r += row_stride) {
    for (int c = first_col * VecSize; c < cols; c += col_stride) {
      // NOTE(review): the two bool template args are presumably
      // (ComputeLayerNorm = false, Activation = true), and the trailing
      // nullptrs after is_test are presumably the unused layer-norm
      // mean/variance outputs — confirm in fused_residual_dropout_bias.h.
      FusedResidualDropoutBiasOneThread<T,
                                        MaskType,
                                        VecSize,
                                        false,
                                        true,
                                        Functor,
                                        InType,
                                        OutType>(r,
                                                 c,
                                                 cols,
                                                 &rng,
                                                 dropout_prob,
                                                 factor,
                                                 src,
                                                 nullptr,  // presumably residual
                                                 bias,
                                                 dst,
                                                 mask,
                                                 is_test,
                                                 nullptr,
                                                 nullptr,
                                                 act,
                                                 quant_last_in_scale,
                                                 dequant_out_scale_data,
                                                 quant_out_scale_offset,
                                                 quant_next_in_scale,
                                                 quant_round_type,
                                                 quant_max_bound,
                                                 quant_min_bound);
    }
  }
}

139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
/**
 * @brief dst = activation(src + bias), with no dropout and no quantization.
 *
 * Elementwise over a flat buffer of `elem_cnt` elements with `cols` columns
 * per row. `bias` may be nullptr; when present it is indexed by column
 * (idx % cols). Each thread handles VecSize contiguous elements per iteration
 * and strides by blockDim.x * gridDim.x * VecSize.
 *
 * Preconditions (not checked here; the launcher only uses this kernel when
 * cols % VecSize == 0):
 *  - elem_cnt and cols are multiples of VecSize, so a vector never straddles
 *    a row or runs past the end.
 *  - src/bias/dst are aligned for phi::AlignedVector<., VecSize>.
 */
template <typename T,
          int VecSize,
          typename Functor,
          typename InType = T,
          typename OutType = T>
__global__ void FusedActBias(Functor act,
                             const uint64_t elem_cnt,
                             const uint64_t cols,
                             const InType *__restrict__ src,
                             const T *__restrict__ bias,
                             OutType *dst) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using LoadInType = phi::AlignedVector<InType, VecSize>;
  using StoreOutType = phi::AlignedVector<OutType, VecSize>;

  // 64-bit indexing: elem_cnt is uint64_t. A 32-bit signed index (and the
  // 32-bit stride product) overflows, which is undefined behaviour, once the
  // buffer reaches 2^31 elements.
  const uint64_t global_thread_idx =
      static_cast<uint64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const uint64_t step =
      static_cast<uint64_t>(blockDim.x) * gridDim.x * VecSize;

  LoadInType src_vec;
  LoadT bias_vec;
  StoreOutType out_vec;
  for (uint64_t idx = global_thread_idx * VecSize; idx < elem_cnt;
       idx += step) {
    phi::Load<InType, VecSize>(&src[idx], &src_vec);
    if (bias) {
      // Column of the first element of this vector; the whole vector stays
      // in one row by the precondition above.
      const uint64_t col_idx = idx % cols;
      phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
    }
#pragma unroll
    for (int unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
      if (bias) {
        out_vec[unroll_idx] = static_cast<OutType>(
            act(static_cast<T>(src_vec[unroll_idx]) + bias_vec[unroll_idx]));
      } else {
        out_vec[unroll_idx] =
            static_cast<OutType>(act(static_cast<T>(src_vec[unroll_idx])));
      }
    }
    phi::Store<OutType, VecSize>(out_vec, &dst[idx]);
  }
}

182 183 184
/**
 * @brief dst = dropout(activation(src + bias));
 *
 * Host-side launcher on ctx.stream(). src, dst and mask_data hold
 * rows * cols elements; bias holds cols elements.
 *
 * Dispatch:
 *  - |dropout_prob - 1| < 1e-5: dst and mask_data are zero-filled and no
 *    kernel is launched.
 *  - cols % VecSize == 0, is_test, no dequant scales, and no inference-time
 *    scaling needed (upscale_in_train mode, or dropout_prob == 0): the
 *    FusedActBias kernel computes activation(src + bias) only. It generates
 *    no RNG, writes no mask and applies no quantization.
 *  - Otherwise FusedDropoutActBias runs with vector width VecSize when
 *    cols % VecSize == 0, or 1 when it is not. All quantization parameters
 *    are forwarded to it.
 */
template <typename T,
          typename MaskType,
          typename Functor,
          typename InType = T,
          typename OutType = T>
void LaunchDropoutActBias(Functor act_functor,
                          const uint64_t seed,
                          const uint32_t rows,
                          const uint32_t cols,
                          const int increment,
                          const float dropout_prob,
                          const bool is_upscale_in_train,
                          const bool is_test,
                          const InType *src,
                          const T *bias,
                          OutType *dst,
                          MaskType *mask_data,
                          const phi::GPUContext &ctx,
                          const float quant_last_in_scale = 1.0,
                          const float *dequant_out_scale_data = nullptr,
                          const int quant_out_scale_offset = 0,
                          const float quant_next_in_scale = 1.0,
                          const int quant_round_type = 1,
                          const float quant_max_bound = 127.0,
                          const float quant_min_bound = -127.0) {
  // Widen before multiplying: rows * cols in uint32_t can wrap.
  const size_t numel = static_cast<size_t>(rows) * cols;

  // dropout_prob == 1.0f: every element is dropped.
  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
    // dst holds numel OutType elements. Zeroing it as T (the old
    // reinterpret_cast) would write past the end whenever
    // sizeof(OutType) < sizeof(T), e.g. an int8 output with a half T.
    SetZero<OutType>(ctx, dst, numel);
    SetZero<MaskType>(ctx, mask_data, numel);
    return;
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
  const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);

  // FusedActBias applies no dropout factor. At inference that is only correct
  // when the factor is 1: upscale_in_train mode leaves activations unscaled,
  // while downscale_in_infer multiplies them by (1 - dropout_prob).
  const bool no_infer_scaling = is_upscale_in_train || dropout_prob == 0.0f;

  if (cols % VecSize == 0) {
    if (is_test && no_infer_scaling && (dequant_out_scale_data == nullptr)) {
      // NOTE(review): this path also ignores quant_next_in_scale,
      // quant_round_type and the quant bounds (no output quantization).
      // Confirm that callers using a quantized OutType never reach it.
      const int32_t elem_cnt = rows * cols;
      const int32_t pack_num = elem_cnt / VecSize;
      const int32_t tmp_cols = cols / VecSize;
      int block_size =
          std::max(static_cast<int32_t>(32), std::min(tmp_cols, 128));
      const int grid_size = std::max(static_cast<int32_t>(1),
                                     (pack_num + block_size - 1) / block_size);
      FusedActBias<T, VecSize, Functor, InType, OutType>
          <<<grid_size, block_size, 0, ctx.stream()>>>(
              act_functor, elem_cnt, cols, src, bias, dst);
    } else {
      FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              act_functor,
              seed,
              rows,
              cols,
              increment,
              dropout_prob,
              is_upscale_in_train,
              is_test,
              src,
              bias,
              dst,
              mask_data,
              quant_last_in_scale,
              dequant_out_scale_data,
              quant_out_scale_offset,
              quant_next_in_scale,
              quant_round_type,
              quant_max_bound,
              quant_min_bound);
    }
  } else {
    FusedDropoutActBias<T, MaskType, 1, Functor, InType, OutType>
        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
            act_functor,
            seed,
            rows,
            cols,
            increment,
            dropout_prob,
            is_upscale_in_train,
            is_test,
            src,
            bias,
            dst,
            mask_data,
            quant_last_in_scale,
            dequant_out_scale_data,
            quant_out_scale_offset,
            quant_next_in_scale,
            quant_round_type,
            quant_max_bound,
            quant_min_bound);
  }
}

/**
 * @brief Backward pass without bias:
 *   dx = dout * mask * factor * act_grad.UseOut(src).
 *
 * Elementwise over `size` elements. Each thread handles VecSize contiguous
 * elements per iteration and strides by blockDim.x * gridDim.x * VecSize.
 *
 * Preconditions (not checked; the launcher selects VecSize only when
 * size % VecSize == 0):
 *  - size is a multiple of VecSize.
 *  - all pointers are aligned for phi::AlignedVector<., VecSize>.
 */
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActGrad(Functor act_grad,
                                    const T *dout,
                                    const MaskType *mask,
                                    const T *src,
                                    const T factor,
                                    const int64_t size,
                                    T *dx) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  // 64-bit index and stride: `size` is int64_t. An `int` loop index
  // truncates/overflows once the tensor reaches 2^31 elements.
  const int64_t idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;

  for (int64_t i = idx * VecSize; i < size; i += stride) {
    LoadT dout_vec;
    LoadT src_vec;
    MaskLoadT mask_vec;

    phi::Load<T, VecSize>(&dout[i], &dout_vec);
    phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);
    phi::Load<T, VecSize>(&src[i], &src_vec);

    StoreT dx_vec;
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      T tmp = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
      dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]);
    }
    phi::Store<T, VecSize>(dx_vec, &dx[i]);
  }
}

/**
 * @brief Backward pass with bias:
 *   dx    = dout * mask * factor * act_grad.UseOut(src + bias)
 *   dbias = column-wise sum of dx
 *
 * Expected launch (as done by LaunchDropoutActBiasGrad):
 *  - block = (BlockSizeX, BlockSizeY) = (8, 128).
 *  - grid.x covers ceil(cols / VecSize / BlockSizeX) column groups.
 *
 * Steps:
 * 1. Each thread owns VecSize consecutive columns starting at
 *    col_id * VecSize. It walks rows threadIdx.y, threadIdx.y + blockDim.y, …,
 *    writes dx, and accumulates per-column partial sums in tmp_sum.
 * 2. CalculateDBias (defined elsewhere) combines the partial sums across
 *    threadIdx.y into dbias. It presumably uses shared memory and block
 *    barriers — confirm.
 *
 * Every thread calls CalculateDBias, including threads whose columns are out
 * of range (their tmp_sum stays zero). Any barrier inside it is therefore
 * reached uniformly.
 *
 * Preconditions: bias and dbias are non-null; cols is a multiple of VecSize.
 */
template <typename T,
          typename MaskType,
          int BlockSizeX,
          int BlockSizeY,
          int VecSize,
          typename Functor,
          int THREADS_PER_CTA = BlockSizeX * BlockSizeY>
__global__ __launch_bounds__(THREADS_PER_CTA) void FusedDropoutActBiasGrad(
    Functor act_grad,
    const T *dout,
    const MaskType *mask,
    const T *src,
    const T *bias,
    const T factor,
    const int64_t rows,
    const int64_t cols,
    T *dx,
    T *dbias) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  const int64_t col_id =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t col_offset = col_id * VecSize;

  T tmp_sum[VecSize] = {static_cast<T>(0)};
  // calculate the dx and temporary sum
  if (col_offset < cols) {
    // bias depends only on the column, so load it once rather than per row.
    LoadT bias_vec;
    phi::Load<T, VecSize>(&bias[col_offset], &bias_vec);

    // 64-bit row/element indices: rows * cols may exceed INT_MAX, and the
    // old `int index` silently truncated the int64_t product.
    for (int64_t row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
      const int64_t index = row_id * cols + col_offset;
      LoadT dout_vec;
      LoadT src_vec;
      MaskLoadT mask_vec;

      phi::Load<T, VecSize>(&dout[index], &dout_vec);
      phi::Load<T, VecSize>(&src[index], &src_vec);
      phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);

      StoreT dx_vec;
#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        T tmp = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
        T val = tmp * act_grad.UseOut(src_vec[i] + bias_vec[i]);
        dx_vec[i] = val;
        tmp_sum[i] += val;
      }
      phi::Store<T, VecSize>(dx_vec, &dx[index]);
    }
  }

  CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}

/**
 * @brief Launch the backward pass of dst = dropout(activation(src + bias)).
 *
 * factor is:
 *  - 1 / (1 - dropout_prob) in upscale_in_train mode (0 when
 *    dropout_prob == 1);
 *  - 1 otherwise.
 *
 * Dispatch:
 *  - dbias != nullptr: FusedDropoutActBiasGrad computes dx and dbias on
 *    (8, 128) thread blocks, with vector width VecSize when
 *    cols % VecSize == 0, else 1.
 *  - dbias == nullptr: FusedDropoutActGrad computes dx only, over a 1-D
 *    launch of rows * cols elements, with vector width VecSize when
 *    rows * cols is a multiple of VecSize, else 1.
 * All kernels run on ctx.stream().
 */
template <typename T, typename MaskType, typename Functor>
void LaunchDropoutActBiasGrad(Functor act_functor,
                              const T *dout,
                              const MaskType *mask,
                              const T *src,
                              const T *bias,
                              const float dropout_prob,
                              const bool is_upscale_in_train,
                              const uint32_t rows,
                              const uint32_t cols,
                              T *dx,
                              T *dbias,
                              const phi::GPUContext &ctx) {
  const T zero = static_cast<T>(0.0);
  auto factor = dropout_prob == static_cast<float>(1.0f)
                    ? zero
                    : static_cast<T>(1.0 / (1.0 - dropout_prob));
  if (!is_upscale_in_train) {
    factor = static_cast<T>(1.0f);
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);

  if (dbias != nullptr) {
    // Keep the launch dims and the kernel's BlockSizeX/BlockSizeY template
    // arguments in sync through one pair of constants.
    constexpr int kBlockSizeX = 8;
    constexpr int kBlockSizeY = 128;
    const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
    const auto blocks =
        std::max(static_cast<uint32_t>(1),
                 (cols / real_vec_size + kBlockSizeX - 1) / kBlockSizeX);
    dim3 block_dim(kBlockSizeX, kBlockSizeY, 1);
    dim3 grid_dim(blocks, 1, 1);
    if (cols % VecSize == 0) {
      FusedDropoutActBiasGrad<T,
                              MaskType,
                              kBlockSizeX,
                              kBlockSizeY,
                              VecSize,
                              Functor>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(act_functor,
                                                     dout,
                                                     mask,
                                                     src,
                                                     bias,
                                                     factor,
                                                     rows,
                                                     cols,
                                                     dx,
                                                     dbias);
    } else {
      FusedDropoutActBiasGrad<T, MaskType, kBlockSizeX, kBlockSizeY, 1, Functor>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(act_functor,
                                                     dout,
                                                     mask,
                                                     src,
                                                     bias,
                                                     factor,
                                                     rows,
                                                     cols,
                                                     dx,
                                                     dbias);
    }
  } else {
    // NOTE(review): FusedDropoutActGrad evaluates act_grad at src, not
    // src + bias, so a non-null `bias` is ignored on this path. Confirm that
    // callers never pass a bias without also requesting dbias.
    //
    // Widen before multiplying: rows * cols in uint32_t can wrap.
    const uint64_t n = static_cast<uint64_t>(rows) * cols;
    // Size the launch with the same vector width the dispatch below uses.
    // (The old code used the cols-based width, which over-launched when
    // n % VecSize == 0 but cols % VecSize != 0.)
    const bool vectorized = n % VecSize == 0;
    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx, vectorized ? n / VecSize : n);
    if (vectorized) {
      FusedDropoutActGrad<T, MaskType, VecSize, Functor>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              act_functor, dout, mask, src, factor, n, dx);
    } else {
      FusedDropoutActGrad<T, MaskType, 1, Functor>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              act_functor, dout, mask, src, factor, n, dx);
    }
  }
}

}  // namespace operators
}  // namespace paddle