// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "glog/logging.h"

#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h"

PHI_DECLARE_bool(use_fast_math);

namespace phi {
namespace fusion {

#ifndef PADDLE_WITH_HIP

// Fused bias-add + gated activation (GLU variants). Each input row holds
// 2 * hid_dim elements: the first half goes through the activation and is
// multiplied element-wise with the second half, producing hid_dim outputs
// per row:
//   out[bi][idx] = act(x[bi][idx] + bias[idx]) *
//                  (x[bi][hid_dim + idx] + bias[hid_dim + idx])
template <typename T,
          typename Functor,
          int VecSize,
          typename LoadFunc,
          typename StoreFunc>
__global__ void ActFFNGlu(const T *bias,
                          Functor act_functor,
                          const int token_num,
                          const int hid_dim,
                          const int elem_num,
                          LoadFunc load_func,
                          StoreFunc store_func) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  LoadT src_vec1;
  LoadT src_vec2;
  LoadT bias_vec1;
  LoadT bias_vec2;
  const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = global_tid * VecSize; i < elem_num;
       i += gridDim.x * blockDim.x * VecSize) {
    int bi = i / hid_dim;
    int idx = i % hid_dim;
    // The gate and the value live in the same row, hid_dim elements apart.
    load_func.template load<VecSize>(&src_vec1, bi * hid_dim * 2 + idx);
    load_func.template load<VecSize>(&src_vec2,
                                     bi * hid_dim * 2 + idx + hid_dim);
    if (bias) {
      phi::Load<T, VecSize>(&bias[idx], &bias_vec1);
      phi::Load<T, VecSize>(&bias[idx + hid_dim], &bias_vec2);
    }
#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      if (bias) {
        src_vec1[j] += bias_vec1[j];
        src_vec2[j] += bias_vec2[j];
      }
      src_vec1[j] = act_functor(src_vec1[j]);
      src_vec1[j] *= src_vec2[j];
    }
    store_func.template store<VecSize>(src_vec1, bi * hid_dim + idx);
  }
}

template <typename T,
          typename Context,
          typename Functor,
          typename LoadFunc,
          typename StoreFunc,
          typename LoadT = T>
void LaunchActFFNGlu(const Context &dev_ctx,
                     const T *bias,
                     const int token_num,
                     const int hid_dim,
                     LoadFunc load_func,
                     StoreFunc store_func) {
  constexpr int VecSize = 16;
  constexpr int PackSize = VecSize / sizeof(LoadT);
  const int elem_cnt = token_num * hid_dim;
  const int blocksize = 128;
  int grid_size = 1;
  Functor functor;
  switch (hid_dim % PackSize) {
    case 0:
      // hid_dim is divisible by the pack width: use the vectorized kernel.
      GetNumBlocks(elem_cnt / PackSize, &grid_size);
      ActFFNGlu<T, Functor, PackSize, LoadFunc, StoreFunc>
          <<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
              bias, functor, token_num, hid_dim, elem_cnt, load_func,
              store_func);
      break;
    default:
      // Otherwise fall back to the scalar (VecSize = 1) kernel.
      GetNumBlocks(elem_cnt, &grid_size);
      ActFFNGlu<T, Functor, 1, LoadFunc, StoreFunc>
          <<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
              bias, functor, token_num, hid_dim, elem_cnt, load_func,
              store_func);
      break;
  }
}

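// BiasAct applies a plain bias-add + activation over a [rows, cols] input,
// with the bias (when present) broadcast along rows. As an illustrative,
// non-vectorized reference only (not part of the kernel), each element is
// computed as:
//
//   out[r][c] = act(x[r][c] + (bias ? bias[c] : 0));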
template <typename T,
          typename Functor,
          int VecSize,
          typename LoadFunc,
          typename StoreFunc>
__global__ void BiasAct(const T *bias,
                        Functor act_functor,
                        const int rows,
                        const int cols,
                        const int elem_num,
                        LoadFunc load_func,
                        StoreFunc store_func) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  LoadT src_vec;
  LoadT bias_vec;

// Zero-initialize bias_vec.
#pragma unroll
  for (int unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
    bias_vec[unroll_idx] = 0;
  }

  const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = global_tid * VecSize; i < elem_num;
       i += gridDim.x * blockDim.x * VecSize) {
    int row_idx = i / cols;
    int col_idx = i % cols;
    int linear_idx = row_idx * cols + col_idx;
    load_func.template load<VecSize>(&src_vec, linear_idx);
    if (bias) {
      phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
    }
#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      if (bias) {
        src_vec[j] += bias_vec[j];
      }
      src_vec[j] = act_functor(src_vec[j]);
    }
    store_func.template store<VecSize>(src_vec, linear_idx);
  }
}

template <typename T,
          typename Context,
          typename Functor,
          typename LoadFunc,
          typename StoreFunc,
          typename LoadT = T>
void LaunchBiasAct(const Context &dev_ctx,
                   const T *bias,
                   const int token_num,
                   const int hid_dim,
                   LoadFunc load_func,
                   StoreFunc store_func) {
  constexpr int VecSize = 16;
  constexpr int PackSize = VecSize / sizeof(LoadT);
  const int elem_cnt = token_num * hid_dim;
  const int blocksize = 128;
  int grid_size = 1;
  Functor functor;
  switch (hid_dim % PackSize) {
    case 0:
      GetNumBlocks(elem_cnt / PackSize, &grid_size);
      BiasAct<T, Functor, PackSize, LoadFunc, StoreFunc>
          <<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
              bias, functor, token_num, hid_dim, elem_cnt, load_func,
              store_func);
      break;
    default:
      GetNumBlocks(elem_cnt, &grid_size);
      BiasAct<T, Functor, 1, LoadFunc, StoreFunc>
          <<<grid_size, blocksize, 0, dev_ctx.stream()>>>(
              bias, functor, token_num, hid_dim, elem_cnt, load_func,
              store_func);
      break;
  }
}

template <typename T,
          typename Context,
          typename LoadFunc,
          typename StoreFunc,
          typename LoadT = T>
void ComputeImpl(const Context &dev_ctx,
                 const T *bias_data,
                 const std::string &act_method,
                 int rows,
                 int cols,
                 LoadFunc load_func,
                 StoreFunc store_func) {
  if (act_method == "geglu") {
    // Note(Zhengzekang): For the GLU structure, we need to divide cols by 2.
    VLOG(8) << "Doing geglu";
    LaunchActFFNGlu<T, Context, GeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
        dev_ctx, bias_data, rows, cols / 2, load_func, store_func);
  } else if (act_method == "swiglu") {
    VLOG(8) << "Doing swiglu";
    LaunchActFFNGlu<T,
                    Context,
                    CudaSwishFunctor<T>,
                    LoadFunc,
                    StoreFunc,
                    LoadT>(
        dev_ctx, bias_data, rows, cols / 2, load_func, store_func);
  } else if (act_method == "gelu") {
    if (FLAGS_use_fast_math) {
      VLOG(8) << "Doing Fast GELU";
      LaunchBiasAct<T,
                    Context,
                    FastGeluFunctor<T>,
                    LoadFunc,
                    StoreFunc,
                    LoadT>(
          dev_ctx, bias_data, rows, cols, load_func, store_func);
    } else {
      VLOG(8) << "Doing GELU";
      LaunchBiasAct<T, Context, GeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
          dev_ctx, bias_data, rows, cols, load_func, store_func);
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Currently only GeGLU, SwiGLU and GeLU are supported."));
  }
}

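// The two DispatchComputeImpl overloads below choose the load/store functors
// for the (de)quantization variants:
//   * dequant_scales given -> DequantLoad<T> reads int32_t GEMM output and
//     dequantizes it to T; otherwise Load<T> reads T directly.
//   * quant_scale > 0      -> QuantStore quantizes the activated result to
//     int8_t; otherwise Store writes T.
// The second overload additionally applies per-channel shift/smooth vectors
// before storing.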
template <typename T, typename Context>
void DispatchComputeImpl(const Context &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor *bias,
                         const DenseTensor *dequant_scales,
                         const std::string &act_method,
                         int rows, int cols, const float quant_scale,
                         const int quant_round_type,
                         const float quant_max_bound,
                         const float quant_min_bound,
                         DenseTensor *out) {
  const T *bias_data = bias == nullptr ? nullptr : bias->data<T>();
  if (dequant_scales != nullptr && quant_scale > 0) {
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    QuantStore<T> store_func(dev_ctx.template Alloc<int8_t>(out),
                             quant_round_type, quant_scale, quant_max_bound,
                             quant_min_bound);
    ComputeImpl<T, Context, DequantLoad<T>, QuantStore<T>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales == nullptr && quant_scale > 0) {
    Load<T> load_func(x.data<T>());
    QuantStore<T> store_func(dev_ctx.template Alloc<int8_t>(out),
                             quant_round_type, quant_scale, quant_max_bound,
                             quant_min_bound);
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales != nullptr && quant_scale <= 0) {
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    Store<T> store_func(dev_ctx.template Alloc<T>(out));
    ComputeImpl<T, Context, DequantLoad<T>, Store<T>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else {
    Load<T> load_func(x.data<T>());
    Store<T> store_func(dev_ctx.template Alloc<T>(out));
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  }
}

template <typename T, typename Context>
void DispatchComputeImpl(const Context &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor *bias,
                         const DenseTensor *dequant_scales,
                         const DenseTensor *shift,
                         const DenseTensor *smooth,
                         const std::string &act_method,
                         int rows, int cols, const float quant_scale,
                         const int quant_round_type,
                         const float quant_max_bound,
                         const float quant_min_bound,
                         DenseTensor *out) {
  // For GLU activations the output row is half as wide as the input row, so
  // the shift/smooth vectors cover cols / 2 elements.
  bool use_glu = (act_method == "geglu" || act_method == "swiglu");
  const T *bias_data = bias == nullptr ? nullptr : bias->data<T>();
  if (dequant_scales != nullptr && quant_scale > 0) {
    int8_t *out_data = dev_ctx.template Alloc<int8_t>(out);
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    QuantStore<T, true> store_func(out_data, shift->data<T>(),
                                   smooth->data<T>(),
                                   use_glu ? cols / 2 : cols,
                                   quant_round_type, quant_scale,
                                   quant_max_bound, quant_min_bound);
    ComputeImpl<T, Context, DequantLoad<T>, QuantStore<T, true>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales == nullptr && quant_scale > 0) {
    Load<T> load_func(x.data<T>());
    QuantStore<T, true> store_func(dev_ctx.template Alloc<int8_t>(out),
                                   shift->data<T>(), smooth->data<T>(),
                                   use_glu ? cols / 2 : cols,
                                   quant_round_type, quant_scale,
                                   quant_max_bound, quant_min_bound);
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales != nullptr && quant_scale <= 0) {
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    Store<T, true> store_func(dev_ctx.template Alloc<T>(out),
                              shift->data<T>(), smooth->data<T>(),
                              use_glu ? cols / 2 : cols);
    ComputeImpl<T, Context, DequantLoad<T>, Store<T, true>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else {
    Load<T> load_func(x.data<T>());
    Store<T, true> store_func(dev_ctx.template Alloc<T>(out),
                              shift->data<T>(), smooth->data<T>(),
                              use_glu ? cols / 2 : cols);
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  }
}

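// Tag types for dispatch: the kernel is also registered for int32_t (raw
// quantized GEMM output), but for that dtype the real compute type comes from
// Attr(compute_dtype). DispatchDtypeTrait maps int32_t to UnusedVersion so
// the T = int32_t instantiation of DispatchWithDtype resolves to the empty
// stub below instead of generating unused kernel code.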
struct NormalVersion {};
struct UnusedVersion {};

template <typename T>
struct DispatchDtypeTrait {
  using FuncVersion = NormalVersion;
};

template <>
struct DispatchDtypeTrait<int32_t> {
  using FuncVersion = UnusedVersion;
};

template <typename T, typename Context>
void DispatchWithDtype(const Context &dev_ctx,
                       const DenseTensor &x,
                       const paddle::optional<DenseTensor> &bias,
                       const paddle::optional<DenseTensor> &dequant_scales,
                       const paddle::optional<DenseTensor> &shift,
                       const paddle::optional<DenseTensor> &smooth,
                       const std::string &act_method,
                       int rows, int cols, float quant_scale,
                       int quant_round_type, float quant_max_bound,
                       float quant_min_bound, DenseTensor *out,
                       NormalVersion) {
  auto *bias_p = bias.get_ptr();
  auto *dequant_scales_p = dequant_scales.get_ptr();
  auto *shift_p = shift.get_ptr();
  auto *smooth_p = smooth.get_ptr();
  if (dequant_scales_p != nullptr) {
    if (shift_p != nullptr) {
      DispatchComputeImpl<T>(dev_ctx, x, bias_p, dequant_scales_p, shift_p,
                             smooth_p, act_method, rows, cols, quant_scale,
                             quant_round_type, quant_max_bound,
                             quant_min_bound, out);
    } else {
      DispatchComputeImpl<T>(dev_ctx, x, bias_p, dequant_scales_p, act_method,
                             rows, cols, quant_scale, quant_round_type,
                             quant_max_bound, quant_min_bound, out);
    }
  } else {
    const T *bias_data = bias_p == nullptr ? nullptr : bias_p->data<T>();
    Load<T> load_func(x.data<T>());
    Store<T> store_func(dev_ctx.template Alloc<T>(out));
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  }
}

// (not used) only for registering int32_t
template <typename T, typename Context>
void DispatchWithDtype(const Context &dev_ctx,
                       const DenseTensor &x,
                       const paddle::optional<DenseTensor> &bias,
                       const paddle::optional<DenseTensor> &dequant_scales,
                       const paddle::optional<DenseTensor> &shift,
                       const paddle::optional<DenseTensor> &smooth,
                       const std::string &act_method,
                       int rows, int cols, float quant_scale,
                       int quant_round_type, float quant_max_bound,
                       float quant_min_bound, DenseTensor *out,
                       UnusedVersion) {}

#endif

template <typename T, typename Context>
void FusedBiasActKernel(const Context &dev_ctx,
                        const DenseTensor &x,
                        const paddle::optional<DenseTensor> &bias,
                        const paddle::optional<DenseTensor> &dequant_scales,
                        const paddle::optional<DenseTensor> &shift,
                        const paddle::optional<DenseTensor> &smooth,
                        const std::string &act_method,
                        const std::string &compute_dtype,
                        int rows, int cols, float quant_scale,
                        int quant_round_type, float quant_max_bound,
                        float quant_min_bound, DenseTensor *out) {
#ifndef PADDLE_WITH_HIP
  if (x.dtype() == phi::DataType::INT32) {
    // Quantized path: the INT32 input is the raw GEMM output, and the real
    // math type is chosen by Attr(compute_dtype).
    if (compute_dtype == "bf16") {
      DispatchWithDtype<phi::dtype::bfloat16, Context>(
          dev_ctx, x, bias, dequant_scales, shift, smooth, act_method, rows,
          cols, quant_scale, quant_round_type, quant_max_bound,
          quant_min_bound, out,
          typename DispatchDtypeTrait<phi::dtype::bfloat16>::FuncVersion{});
    } else if (compute_dtype == "fp16") {
      DispatchWithDtype<phi::dtype::float16, Context>(
          dev_ctx, x, bias, dequant_scales, shift, smooth, act_method, rows,
          cols, quant_scale, quant_round_type, quant_max_bound,
          quant_min_bound, out,
          typename DispatchDtypeTrait<phi::dtype::float16>::FuncVersion{});
    } else if (compute_dtype == "fp32") {
      DispatchWithDtype<float, Context>(
          dev_ctx, x, bias, dequant_scales, shift, smooth, act_method, rows,
          cols, quant_scale, quant_round_type, quant_max_bound,
          quant_min_bound, out,
          typename DispatchDtypeTrait<float>::FuncVersion{});
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "When quantization is enabled with an INT32 Input(x), "
          "Attr(compute_dtype) must be one of (bf16, fp16, fp32), "
          "but got compute_dtype (%s)",
          compute_dtype));
    }
  } else {
    DispatchWithDtype<T, Context>(
        dev_ctx, x, bias, dequant_scales, shift, smooth, act_method, rows,
        cols, quant_scale, quant_round_type, quant_max_bound, quant_min_bound,
        out, typename DispatchDtypeTrait<T>::FuncVersion{});
  }
#endif
}

}  // namespace fusion
}  // namespace phi

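// int32_t is registered only so the quantized (INT32 GEMM output) path can be
// reached; the compute type is then selected via Attr(compute_dtype). On HIP
// builds the dispatch body above is compiled out, so this kernel does no work
// there.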
PD_REGISTER_KERNEL(fused_bias_act,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedBiasActKernel,
                   float,
                   phi::dtype::bfloat16,
                   phi::dtype::float16,
                   int32_t) {}