// fused_dropout_act_bias.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

20
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62

namespace paddle {
namespace operators {

/**
 * @brief GELU activation functor: returns 0.5 * x * (1 + erf(x / sqrt(2))).
 *
 * The arithmetic is carried out in LayerNormParamType<T> (declared in another
 * header; presumably a wider type for reduced-precision T -- confirm) and the
 * result is cast back to T.
 *
 * @param x  input value.
 * @return   GELU(x) converted to T.
 */
template <typename T>
struct GeluFunctor {
  inline __host__ __device__ T operator()(const T x) const {
    using ComputeT = LayerNormParamType<T>;
    const ComputeT half = static_cast<ComputeT>(0.5);
    const ComputeT one = static_cast<ComputeT>(1);
    const ComputeT inv_sqrt2 = static_cast<ComputeT>(M_SQRT1_2);

    const ComputeT value = static_cast<ComputeT>(x);
    // erf(x / sqrt(2)); the multiplication order below matches
    // (x * 0.5) * (1 + erf(...)).
    const ComputeT erf_term = erf(value * inv_sqrt2);
    return static_cast<T>(value * half * (one + erf_term));
  }
};

/**
 * @brief Derivative of GELU with respect to its input.
 *
 * Computes
 *   0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi)
 * in LayerNormParamType<T> (declared in another header) and casts the sum
 * back to T. The constant 0.5 * M_2_SQRTPI * M_SQRT1_2 equals 1 / sqrt(2*pi).
 *
 * NOTE(review): despite the method name `UseOut`, the formula is evaluated on
 * the argument as the activation *input* x, not on the activation output.
 *
 * @param x  value at which the derivative is evaluated.
 * @return   dGELU/dx at x, converted to T.
 */
template <typename T>
struct GeluGradFunctor {
  inline __host__ __device__ T UseOut(const T x) const {
    using ComputeT = LayerNormParamType<T>;
    const ComputeT half = static_cast<ComputeT>(0.5);
    const ComputeT inv_sqrt_2pi =
        static_cast<ComputeT>(0.5 * M_2_SQRTPI * M_SQRT1_2);
    const ComputeT value = static_cast<ComputeT>(x);

    // First term: 0.5 * (1 + erf(x / sqrt(2))).
    const auto cdf_term =
        half * (static_cast<ComputeT>(1) +
                erf(value * static_cast<ComputeT>(M_SQRT1_2)));

    // Second term: x * exp(-0.5 * x * x) / sqrt(2 * pi).
    const auto pdf_term = inv_sqrt_2pi * value * exp(-half * value * value);

    return static_cast<T>((cdf_term + pdf_term));
  }
};

/**
 * @brief dst = dropout(activation(src + bias));
 * the src, mask and dst shape is (rows, cols)
 * the bias shape is (1, cols)
 *
 * Launch layout used by this kernel:
 *  - the column of a thread is blockDim.x * blockIdx.x + threadIdx.x, and the
 *    thread starts at element column `col_id * VecSize`, advancing by
 *    blockDim.x * gridDim.x * VecSize;
 *  - the starting row is blockIdx.y and advances by blockDim.y * gridDim.y.
 *    threadIdx.y is not part of the row index, so the kernel presumably
 *    expects blockDim.y == 1 -- confirm against Get1DBlocksAnd2DGrids.
 *
 * Every thread initialises its own Philox state from (seed, idx, increment),
 * where idx = row_id * cols + col_id is used as the cuRAND subsequence.
 *
 * @param act           activation functor forwarded to the per-thread helper.
 * @param seed          cuRAND seed.
 * @param rows, cols    logical shape of src / dst / mask.
 * @param increment     cuRAND offset.
 * @param dropout_prob, is_upscale_in_train, is_test
 *                      dropout configuration; also used to derive the scaling
 *                      factor through GetFactor<T>.
 * @param src, bias     inputs.
 * @param dst, mask     outputs.
 * @param quant_*, dequant_out_scale_data
 *                      forwarded unchanged to
 *                      FusedResidualDropoutBiasOneThread (declared in the
 *                      included header).
 */
template <typename T,
          typename MaskType,
          int VecSize,
          typename Functor,
          typename InType = T,
          typename OutType = T>
__global__ void FusedDropoutActBias(
    Functor act,
    const uint64_t seed,
    const uint64_t rows,
    const uint64_t cols,
    const int increment,
    const float dropout_prob,
    const bool is_upscale_in_train,
    const bool is_test,
    const InType *__restrict__ src,
    const T *__restrict__ bias,
    OutType *dst,
    MaskType *mask,
    const float quant_last_in_scale = 1.0,
    const float *dequant_out_scale_data = nullptr,
    const int quant_out_scale_offset = 0,
    const float quant_next_in_scale = 1.0,
    const int quant_round_type = 1,
    const float quant_max_bound = 127.0,
    const float quant_min_bound = -127.0) {
  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
  int row_id = blockIdx.y;

  // Keep the per-thread subsequence id in 64 bits. `row_id * cols` is a
  // 64-bit product; narrowing it to int (as was done previously) can make
  // different threads share a subsequence once the product exceeds the int
  // range. curand_init takes the subsequence as unsigned long long.
  const uint64_t idx =
      static_cast<uint64_t>(row_id) * cols + static_cast<uint64_t>(col_id);

  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);

  const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);

  // The loop indices stay `int` because they are handed to the per-thread
  // helper, whose parameter types are not visible in this file.
  for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
    for (int i = col_id * VecSize; i < cols;
         i += blockDim.x * gridDim.x * VecSize) {
      // The two boolean template arguments and the nullptr arguments select
      // helper features that this kernel does not use; their exact meaning is
      // defined by the helper's declaration (presumably "no layer-norm
      // statistics", "apply activation", and "no residual / mean / variance
      // buffers" -- confirm).
      FusedResidualDropoutBiasOneThread<T,
                                        MaskType,
                                        VecSize,
                                        false,
                                        true,
                                        Functor,
                                        InType,
                                        OutType>(r,
                                                 i,
                                                 cols,
                                                 &state,
                                                 dropout_prob,
                                                 factor,
                                                 src,
                                                 nullptr,
                                                 bias,
                                                 dst,
                                                 mask,
                                                 is_test,
                                                 nullptr,
                                                 nullptr,
                                                 act,
                                                 quant_last_in_scale,
                                                 dequant_out_scale_data,
                                                 quant_out_scale_offset,
                                                 quant_next_in_scale,
                                                 quant_round_type,
                                                 quant_max_bound,
                                                 quant_min_bound);
    }
  }
}

/**
 * @brief Host launcher for dst = dropout(activation(src + bias)).
 *
 * Behaviour:
 *  - If dropout_prob is within 1e-5 of 1.0, no kernel is launched; dst and
 *    mask_data are zero-filled on ctx.stream() instead.
 *  - Otherwise FusedDropoutActBias is launched on ctx.stream() with the
 *    configuration returned by Get1DBlocksAnd2DGrids. The vectorised
 *    instantiation (VecSize = MAX_CACHE_BYTES / sizeof(T)) is used when cols
 *    is a multiple of VecSize, the scalar one (VecSize = 1) otherwise.
 *
 * @param act_functor   activation functor passed to the kernel.
 * @param seed, increment
 *                      cuRAND seed / offset passed to the kernel.
 * @param rows, cols    logical shape of src / dst / mask_data.
 * @param dropout_prob, is_upscale_in_train, is_test
 *                      dropout configuration.
 * @param src, bias     device inputs.
 * @param dst, mask_data
 *                      device outputs.
 * @param ctx           GPU context supplying the stream.
 * @param quant_*, dequant_out_scale_data
 *                      quantisation parameters; all of them are forwarded to
 *                      the kernel.
 */
template <typename T,
          typename MaskType,
          typename Functor,
          typename InType = T,
          typename OutType = T>
void LaunchDropoutActBias(Functor act_functor,
                          const uint64_t seed,
                          const uint32_t rows,
                          const uint32_t cols,
                          const int increment,
                          const float dropout_prob,
                          const bool is_upscale_in_train,
                          const bool is_test,
                          const InType *src,
                          const T *bias,
                          OutType *dst,
                          MaskType *mask_data,
                          const phi::GPUContext &ctx,
                          const float quant_last_in_scale = 1.0,
                          const float *dequant_out_scale_data = nullptr,
                          const int quant_out_scale_offset = 0,
                          const float quant_next_in_scale = 1.0,
                          const int quant_round_type = 1,
                          const float quant_max_bound = 127.0,
                          const float quant_min_bound = -127.0) {
  // dropout_prob == 1.0f
  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
    // Widen before multiplying: rows and cols are 32-bit and their product
    // may not fit in 32 bits.
    const uint64_t numel = static_cast<uint64_t>(rows) * cols;
    // dst holds OutType elements, so it must be cleared with OutType's
    // element size; clearing it as T writes the wrong number of bytes
    // whenever sizeof(OutType) != sizeof(T).
    SetZero<OutType>(ctx, dst, numel);
    SetZero<MaskType>(ctx, mask_data, numel);
    return;
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  const bool use_vectorized = cols % VecSize == 0;
  const int real_vec_size = use_vectorized ? VecSize : 1;
  const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
  if (use_vectorized) {
    FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
            act_functor,
            seed,
            rows,
            cols,
            increment,
            dropout_prob,
            is_upscale_in_train,
            is_test,
            src,
            bias,
            dst,
            mask_data,
            quant_last_in_scale,
            dequant_out_scale_data,
            quant_out_scale_offset,
            quant_next_in_scale,
            // Forward the remaining quantisation settings; omitting them
            // makes the kernel silently fall back to its own defaults.
            quant_round_type,
            quant_max_bound,
            quant_min_bound);
  } else {
    FusedDropoutActBias<T, MaskType, 1, Functor, InType, OutType>
        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
            act_functor,
            seed,
            rows,
            cols,
            increment,
            dropout_prob,
            is_upscale_in_train,
            is_test,
            src,
            bias,
            dst,
            mask_data,
            quant_last_in_scale,
            dequant_out_scale_data,
            quant_out_scale_offset,
            quant_next_in_scale,
            quant_round_type,
            quant_max_bound,
            quant_min_bound);
  }
}

/*
 * @brief calculate the grad of no bias
 *
 * For every element i in [0, size):
 *   dx[i] = dout[i] * mask[i] * factor * act_grad.UseOut(src[i])
 *
 * Threads are indexed along x only; each thread handles VecSize consecutive
 * elements per iteration and then advances by
 * blockDim.x * gridDim.x * VecSize, so any 1-D grid size covers the input.
 *
 * Precondition: loads and stores are VecSize elements wide and the loop only
 * checks the first element of each vector against `size`, so `size` must be a
 * multiple of VecSize. Alignment requirements are those of phi::Load /
 * phi::Store (declared elsewhere).
 *
 * @param act_grad  functor providing UseOut(x), the activation derivative.
 * @param dout      upstream gradient.
 * @param mask      dropout mask saved by the forward pass.
 * @param src       activation input saved by the forward pass.
 * @param factor    dropout scaling factor.
 * @param size      number of elements.
 * @param dx        output gradient.
 */
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActGrad(Functor act_grad,
                                    const T *dout,
                                    const MaskType *mask,
                                    const T *src,
                                    const T factor,
                                    const int64_t size,
                                    T *dx) {
  // Do the index arithmetic in 64 bits: `size` is int64_t, and both the
  // product blockDim.x * blockIdx.x and an `int` loop index overflow for
  // inputs with more than 2^31 elements.
  const int64_t thread_id =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;

  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
  for (int64_t i = thread_id * VecSize; i < size; i += stride) {
    LoadT dout_vec;
    LoadT src_vec;
    MaskLoadT mask_vec;

    phi::Load<T, VecSize>(&dout[i], &dout_vec);
    phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);
    phi::Load<T, VecSize>(&src[i], &src_vec);

    StoreT dx_vec;
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      // Gradient through dropout first, then through the activation.
      T tmp = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
      dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]);
    }
    phi::Store<T, VecSize>(dx_vec, &dx[i]);
  }
}

/**
 * @brief Gradient of dropout(activation(src + bias)) w.r.t. the
 *        pre-activation value (written to dx), plus per-thread partial
 *        column sums handed to CalculateDBias to produce dbias.
 *
 * blocks(128 * 8)
 * 1. calculate the dx and reduce total rows to 128 rows
 * 2. save 128*8 temporary sum in 8*128 shared memory
 * 3. reduce the sum of 128 cols data by 8*VecSize warps
 * (steps 2 and 3 are performed by CalculateDBias, declared elsewhere.)
 *
 * Thread mapping in this kernel: the x index selects a group of VecSize
 * consecutive columns; threadIdx.y selects the starting row, and rows are
 * visited with a step of blockDim.y. blockIdx.y is not used.
 *
 * Precondition: when VecSize > 1, cols must be a multiple of VecSize because
 * only the first column of each vector is checked against `cols`.
 *
 * @param act_grad    functor providing UseOut(x), the activation derivative.
 * @param dout        upstream gradient, indexed as row * cols + col.
 * @param mask        dropout mask saved by the forward pass.
 * @param src         forward input (before the bias is added).
 * @param bias        bias, indexed by column.
 * @param factor      dropout scaling factor.
 * @param rows, cols  logical shape of dout / src / mask / dx.
 * @param dx          output gradient, same layout as dout.
 * @param dbias       output bias gradient, written by CalculateDBias.
 */
template <typename T,
          typename MaskType,
          int BlockSizeX,
          int BlockSizeY,
          int VecSize,
          typename Functor,
          int THREADS_PER_CTA = BlockSizeX * BlockSizeY>
__global__ __launch_bounds__(THREADS_PER_CTA) void FusedDropoutActBiasGrad(
    Functor act_grad,
    const T *dout,
    const MaskType *mask,
    const T *src,
    const T *bias,
    const T factor,
    const int64_t rows,
    const int64_t cols,
    T *dx,
    T *dbias) {
  const int64_t col_id =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
  T tmp_sum[VecSize] = {static_cast<T>(0)};
  // calculate the dx and temporary sum
  if (col_id * VecSize < cols) {
    const int64_t col_offset = col_id * VecSize;

    // The bias slice depends only on the column, so it is loaded once per
    // thread instead of once per row.
    LoadT bias_vec;
    phi::Load<T, VecSize>(&bias[col_offset], &bias_vec);

    for (int64_t row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
      // 64-bit flat offset: row_id * cols does not fit in `int` once the
      // tensor has more than 2^31 elements.
      const int64_t index = row_id * cols + col_offset;
      LoadT dout_vec;
      LoadT src_vec;
      MaskLoadT mask_vec;

      phi::Load<T, VecSize>(&dout[index], &dout_vec);
      phi::Load<T, VecSize>(&src[index], &src_vec);
      phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);

      StoreT dx_vec;
#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        T val;
        T tmp = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
        val = tmp * act_grad.UseOut(src_vec[i] + bias_vec[i]);
        dx_vec[i] = val;
        tmp_sum[i] += val;
      }
      phi::Store<T, VecSize>(dx_vec, &dx[index]);
    }
  }

  // Called by every thread of the block, including those whose column lies
  // outside [0, cols); such threads contribute zeros. Presumably required
  // because the helper synchronises the block -- confirm in its definition.
  CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}

/**
 * @brief Host launcher for the backward pass of
 *        dropout(activation(src + bias)).
 *
 * - When dbias != nullptr, FusedDropoutActBiasGrad is launched with
 *   8 x 128 thread blocks; it writes dx and dbias.
 * - When dbias == nullptr, FusedDropoutActGrad is launched over the flattened
 *   rows * cols elements; it writes dx only and does not read `bias`.
 *
 * The scaling factor is 1 / (1 - dropout_prob) (0 when dropout_prob == 1),
 * or 1 when is_upscale_in_train is false. Vectorised instantiations
 * (VecSize = MAX_CACHE_BYTES / sizeof(T)) are selected only when the relevant
 * extent is a multiple of VecSize. All kernels run on ctx.stream().
 *
 * @param act_functor   functor providing UseOut(x), the activation derivative.
 * @param dout          upstream gradient.
 * @param mask          dropout mask saved by the forward pass.
 * @param src           forward input.
 * @param bias          forward bias (used only when dbias != nullptr).
 * @param dropout_prob, is_upscale_in_train
 *                      dropout configuration.
 * @param rows, cols    logical shape of dout / src / mask / dx.
 * @param dx            output gradient.
 * @param dbias         output bias gradient, or nullptr to skip it.
 * @param ctx           GPU context supplying the stream.
 */
template <typename T, typename MaskType, typename Functor>
void LaunchDropoutActBiasGrad(Functor act_functor,
                              const T *dout,
                              const MaskType *mask,
                              const T *src,
                              const T *bias,
                              const float dropout_prob,
                              const bool is_upscale_in_train,
                              const uint32_t rows,
                              const uint32_t cols,
                              T *dx,
                              T *dbias,
                              const phi::GPUContext &ctx) {
  const T zero = static_cast<T>(0.0);
  auto factor = dropout_prob == static_cast<float>(1.0f)
                    ? zero
                    : static_cast<T>(1.0 / (1.0 - dropout_prob));
  if (!is_upscale_in_train) {
    factor = static_cast<T>(1.0f);
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);

  if (dbias != nullptr) {
    // Block shape shared by the launch configuration and the kernel's
    // template arguments so that the two cannot drift apart.
    constexpr int kBlockSizeX = 8;
    constexpr int kBlockSizeY = 128;

    const bool use_vectorized = cols % VecSize == 0;
    const int real_vec_size = use_vectorized ? VecSize : 1;
    const auto threads = kBlockSizeX;
    const auto blocks =
        std::max(static_cast<uint32_t>(1),
                 (cols / real_vec_size + threads - 1) / threads);
    dim3 block_dim(kBlockSizeX, kBlockSizeY, 1);
    dim3 grid_dim(blocks, 1, 1);
    if (use_vectorized) {
      FusedDropoutActBiasGrad<T,
                              MaskType,
                              kBlockSizeX,
                              kBlockSizeY,
                              VecSize,
                              Functor>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(act_functor,
                                                     dout,
                                                     mask,
                                                     src,
                                                     bias,
                                                     factor,
                                                     rows,
                                                     cols,
                                                     dx,
                                                     dbias);
    } else {
      FusedDropoutActBiasGrad<T,
                              MaskType,
                              kBlockSizeX,
                              kBlockSizeY,
                              1,
                              Functor>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(act_functor,
                                                     dout,
                                                     mask,
                                                     src,
                                                     bias,
                                                     factor,
                                                     rows,
                                                     cols,
                                                     dx,
                                                     dbias);
    }
  } else {
    // Widen before multiplying: rows and cols are 32-bit, so `rows * cols`
    // would wrap around before being stored in the 64-bit variable.
    const uint64_t n = static_cast<uint64_t>(rows) * cols;
    // Size the launch with the same vector width the selected kernel uses.
    const bool use_vectorized = n % VecSize == 0;
    const int real_vec_size = use_vectorized ? VecSize : 1;
    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
    if (use_vectorized) {
      FusedDropoutActGrad<T, MaskType, VecSize, Functor>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              act_functor, dout, mask, src, factor, n, dx);
    } else {
      FusedDropoutActGrad<T, MaskType, 1, Functor>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              act_functor, dout, mask, src, factor, n, dx);
    }
  }
}

}  // namespace operators
}  // namespace paddle