/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"

namespace paddle {
namespace operators {

/**
 * @brief Exact (erf-based) GELU activation:
 *   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
 * The arithmetic is carried out in LayerNormParamType<T> (presumably a wider
 * type than T for half-precision inputs — confirm) and cast back to T.
 */
template <typename T>
struct GeluFunctor {
  inline __host__ __device__ T operator()(const T x) const {
    using U = LayerNormParamType<T>;
    const U casted_x = static_cast<U>(x);
    const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
    const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
    return static_cast<T>(out);
  }
};

/**
 * @brief Derivative of the exact GELU with respect to its input:
 *   gelu'(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi)
 * (0.5 * M_2_SQRTPI * M_SQRT1_2 == 1 / sqrt(2 * pi)).
 * Note: despite its name, UseOut() takes the forward *input* x — the grad
 * kernels below pass src or src + bias, not the activation output.
 */
template <typename T>
struct GeluGradFunctor {
  inline __host__ __device__ T UseOut(const T x) const {
    using U = LayerNormParamType<T>;
    auto casted_x = static_cast<U>(x);

    // 0.5 * (1 + erf(x / sqrt(2))): the normal CDF at x.
    auto first =
        static_cast<U>(0.5) *
        (static_cast<U>(1) + erf(casted_x * static_cast<U>(M_SQRT1_2)));

    // x * pdf(x) with pdf(x) = exp(-x^2 / 2) / sqrt(2 * pi).
    auto second = static_cast<U>(0.5 * M_2_SQRTPI * M_SQRT1_2) * casted_x *
                  exp(-static_cast<U>(0.5) * casted_x * casted_x);
    return static_cast<T>((first + second));
  }
};

/**
 * @brief dst = dropout(activation(src + bias));
 * the src, mask and dst shape is (rows, cols)
 * the bias shape is (1, cols)
 *
 * Each thread handles VecSize consecutive columns starting at
 * col_id * VecSize and strides over columns by blockDim.x * gridDim.x *
 * VecSize; rows start at blockIdx.y and stride by blockDim.y * gridDim.y.
 * threadIdx.y is not used for indexing, so the launcher is expected to use
 * blocks with blockDim.y == 1 (Get1DBlocksAnd2DGrids — confirm).
 * When VecSize > 1 the caller must guarantee cols % VecSize == 0.
 */
template <typename T,
          typename MaskType,
          int VecSize,
          typename Functor,
          typename InType = T,
          typename OutType = T>
__global__ void FusedDropoutActBias(
    Functor act,
    const uint64_t seed,
    const uint64_t rows,
    const uint64_t cols,
    const int increment,
    const float dropout_prob,
    const bool is_upscale_in_train,
    const bool is_test,
    const InType *__restrict__ src,
    const T *__restrict__ bias,
    OutType *dst,
    MaskType *mask,
    const float quant_last_in_scale = 1.0,
    const float *dequant_out_scale_data = nullptr,
    const int quant_out_scale_offset = 0,
    const float quant_next_in_scale = 1.0,
    const int quant_round_type = 1,
    const float quant_max_bound = 127.0,
    const float quant_min_bound = -127.0) {
  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
  int row_id = blockIdx.y;
  // Philox subsequence id. Kept in 64 bits (curand_init takes an unsigned
  // long long) so it does not wrap for tensors with >= 2^31 elements.
  const uint64_t idx = static_cast<uint64_t>(row_id) * cols + col_id;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);

  const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);

  for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
    for (int i = col_id * VecSize; i < cols;
         i += blockDim.x * gridDim.x * VecSize) {
      // No residual and no layer-norm statistics for this fused op, hence
      // the nullptr arguments.
      FusedResidualDropoutBiasOneThread<T,
                                        MaskType,
                                        VecSize,
                                        false,
                                        true,
                                        Functor,
                                        InType,
                                        OutType>(r,
                                                 i,
                                                 cols,
                                                 &state,
                                                 dropout_prob,
                                                 factor,
                                                 src,
                                                 nullptr,
                                                 bias,
                                                 dst,
                                                 mask,
                                                 is_test,
                                                 nullptr,
                                                 nullptr,
                                                 act,
                                                 quant_last_in_scale,
                                                 dequant_out_scale_data,
                                                 quant_out_scale_offset,
                                                 quant_next_in_scale,
                                                 quant_round_type,
                                                 quant_max_bound,
                                                 quant_min_bound);
    }
  }
}

/**
 * @brief dst = dropout(activation(src + bias));
 *
 * Launches FusedDropoutActBias on ctx.stream(), vectorized by
 * MAX_CACHE_BYTES / sizeof(T) when cols is divisible by that width and
 * scalar otherwise. All quantization parameters are forwarded to the kernel.
 * When dropout_prob == 1 the kernel is skipped and dst / mask_data are
 * zero-filled instead.
 */
template <typename T,
          typename MaskType,
          typename Functor,
          typename InType = T,
          typename OutType = T>
void LaunchDropoutActBias(Functor act_functor,
                          const uint64_t seed,
                          const uint32_t rows,
                          const uint32_t cols,
                          const int increment,
                          const float dropout_prob,
                          const bool is_upscale_in_train,
                          const bool is_test,
                          const InType *src,
                          const T *bias,
                          OutType *dst,
                          MaskType *mask_data,
                          const phi::GPUContext &ctx,
                          const float quant_last_in_scale = 1.0,
                          const float *dequant_out_scale_data = nullptr,
                          const int quant_out_scale_offset = 0,
                          const float quant_next_in_scale = 1.0,
                          const int quant_round_type = 1,
                          const float quant_max_bound = 127.0,
                          const float quant_min_bound = -127.0) {
  // Widen before multiplying: rows * cols in uint32_t wraps past 2^32.
  const uint64_t numel = static_cast<uint64_t>(rows) * cols;

  // dropout_prob == 1.0f: everything is dropped, so the result is all zeros.
  // NOTE(review): in test mode with is_upscale_in_train, dropout is presumably
  // the identity (see GetFactor), so zeroing here may not match what the
  // kernel would compute — confirm the intended semantics.
  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
    // dst holds numel elements of OutType, which may be narrower than T
    // (e.g. a quantized int8 output); zeroing it as T would overrun it.
    SetZero<OutType>(ctx, dst, numel);
    if (mask_data != nullptr) {
      SetZero<MaskType>(ctx, mask_data, numel);
    }
    return;
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
  const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
  if (cols % VecSize == 0) {
    FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
            act_functor,
            seed,
            rows,
            cols,
            increment,
            dropout_prob,
            is_upscale_in_train,
            is_test,
            src,
            bias,
            dst,
            mask_data,
            quant_last_in_scale,
            dequant_out_scale_data,
            quant_out_scale_offset,
            quant_next_in_scale,
            quant_round_type,
            quant_max_bound,
            quant_min_bound);
  } else {
    FusedDropoutActBias<T, MaskType, 1, Functor, InType, OutType>
        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
            act_functor,
            seed,
            rows,
            cols,
            increment,
            dropout_prob,
            is_upscale_in_train,
            is_test,
            src,
            bias,
            dst,
            mask_data,
            quant_last_in_scale,
            dequant_out_scale_data,
            quant_out_scale_offset,
            quant_next_in_scale,
            quant_round_type,
            quant_max_bound,
            quant_min_bound);
  }
}

/**
 * @brief Backward of dropout(act(src)) without a bias term:
 *   dx = dout * mask * factor * act_grad.UseOut(src)
 * The tensors are treated as flat arrays of `size` elements; each thread
 * processes VecSize consecutive elements per iteration of a loop striding
 * by blockDim.x * gridDim.x * VecSize. When VecSize > 1 the caller must
 * guarantee size % VecSize == 0.
 */
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActGrad(Functor act_grad,
                                    const T *dout,
                                    const MaskType *mask,
                                    const T *src,
                                    const T factor,
                                    const int64_t size,
                                    T *dx) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  // 64-bit offsets throughout: a 32-bit loop index overflows once `size`
  // exceeds INT_MAX elements.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  for (int64_t i = idx * VecSize; i < size; i += stride) {
    LoadT dout_vec;
    LoadT src_vec;
    MaskLoadT mask_vec;

    phi::Load<T, VecSize>(&dout[i], &dout_vec);
    phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);
    phi::Load<T, VecSize>(&src[i], &src_vec);

    StoreT dx_vec;
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      T tmp = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
      dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]);
    }
    phi::Store<T, VecSize>(dx_vec, &dx[i]);
  }
}

/**
 * @brief Backward of dropout(act(src + bias)) that also produces dbias.
 *
 * Expected launch: blockDim = (BlockSizeX, BlockSizeY); each thread owns
 * VecSize consecutive columns starting at col_id * VecSize.
 * 1. Every thread walks rows threadIdx.y, threadIdx.y + blockDim.y, ...,
 *    writes dx for its columns and accumulates a per-thread partial column
 *    sum, so the rows are reduced to BlockSizeY partial sums per column.
 * 2. CalculateDBias (defined elsewhere) then reduces those partial sums
 *    into dbias. It is called unconditionally, outside the column bounds
 *    check — presumably because it synchronizes the block; keep it there.
 * When VecSize > 1 the caller must guarantee cols % VecSize == 0.
 */
template <typename T,
          typename MaskType,
          int BlockSizeX,
          int BlockSizeY,
          int VecSize,
          typename Functor>
__global__ void FusedDropoutActBiasGrad(Functor act_grad,
                                        const T *dout,
                                        const MaskType *mask,
                                        const T *src,
                                        const T *bias,
                                        const T factor,
                                        const int64_t rows,
                                        const int64_t cols,
                                        T *dx,
                                        T *dbias) {
  int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;

  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
  T tmp_sum[VecSize] = {static_cast<T>(0)};
  // calculate the dx and temporary sum
  if (col_id * VecSize < cols) {
    const int64_t col_offset = col_id * VecSize;
    // The bias slice depends only on the column, so load it once.
    LoadT bias_vec;
    phi::Load<T, VecSize>(&bias[col_offset], &bias_vec);
    for (int64_t row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
      // 64-bit: row_id * cols exceeds INT_MAX on large inputs.
      const int64_t index = row_id * cols + col_offset;
      LoadT dout_vec;
      LoadT src_vec;
      MaskLoadT mask_vec;

      phi::Load<T, VecSize>(&dout[index], &dout_vec);
      phi::Load<T, VecSize>(&src[index], &src_vec);
      phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);

      StoreT dx_vec;
#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        T tmp = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
        T val = tmp * act_grad.UseOut(src_vec[i] + bias_vec[i]);
        dx_vec[i] = val;
        tmp_sum[i] += val;
      }
      phi::Store<T, VecSize>(dx_vec, &dx[index]);
    }
  }

  CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}

/**
 * @brief Launch the backward of dropout(act(src + bias)) on ctx.stream().
 *
 * dx = dout * mask * factor * act'(...), where factor is 1 / (1 - p) when
 * is_upscale_in_train (0 when p == 1), and 1 otherwise.
 * - dbias != nullptr: launches FusedDropoutActBiasGrad, which evaluates
 *   act' at src + bias and also writes dbias.
 * - dbias == nullptr: launches FusedDropoutActGrad, which evaluates act' at
 *   src alone and never reads `bias` (presumably callers use this path only
 *   when there is no bias — TODO confirm).
 */
template <typename T, typename MaskType, typename Functor>
void LaunchDropoutActBiasGrad(Functor act_functor,
                              const T *dout,
                              const MaskType *mask,
                              const T *src,
                              const T *bias,
                              const float dropout_prob,
                              const bool is_upscale_in_train,
                              const uint32_t rows,
                              const uint32_t cols,
                              T *dx,
                              T *dbias,
                              const phi::GPUContext &ctx) {
  T factor = static_cast<T>(1.0f);
  if (is_upscale_in_train) {
    factor = dropout_prob == 1.0f ? static_cast<T>(0.0)
                                  : static_cast<T>(1.0 / (1.0 - dropout_prob));
  }

  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  if (dbias != nullptr) {
    // Block shape must match the BlockSizeX / BlockSizeY template arguments.
    constexpr int kBlockSizeX = 8;
    constexpr int kBlockSizeY = 128;
    const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
    const auto blocks =
        std::max(static_cast<uint32_t>(1),
                 (cols / real_vec_size + kBlockSizeX - 1) / kBlockSizeX);
    dim3 block_dim(kBlockSizeX, kBlockSizeY, 1);
    dim3 grid_dim(blocks, 1, 1);
    if (cols % VecSize == 0) {
      FusedDropoutActBiasGrad<T,
                              MaskType,
                              kBlockSizeX,
                              kBlockSizeY,
                              VecSize,
                              Functor>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(
              act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias);
    } else {
      FusedDropoutActBiasGrad<T, MaskType, kBlockSizeX, kBlockSizeY, 1, Functor>
          <<<grid_dim, block_dim, 0, ctx.stream()>>>(
              act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias);
    }
  } else {
    // Widen before multiplying: rows * cols in uint32_t wraps past 2^32.
    const uint64_t n = static_cast<uint64_t>(rows) * cols;
    // This kernel treats the tensor as flat, so the vector width is chosen
    // from n; size the 1-D launch for the variant actually launched.
    const int vec_size_1d = n % VecSize == 0 ? VecSize : 1;
    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx, n / vec_size_1d);
    if (n % VecSize == 0) {
      FusedDropoutActGrad<T, MaskType, VecSize, Functor>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              act_functor, dout, mask, src, factor, n, dx);
    } else {
      FusedDropoutActGrad<T, MaskType, 1, Functor>
          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
              act_functor, dout, mask, src, factor, n, dx);
    }
  }
}

}  // namespace operators
}  // namespace paddle