Unverified commit a7ec8958 authored by Sonder, committed by GitHub

Move fused_attention op to phi [migrate the forward GPU OpKernel] (#51743)

* add kernel functions

* update kernel functions

* update function parameter names

* create codes for gpu device

* adjust file locations

* fix include error

* remove dependent files to phi/

* restore fused_attention_op.cu

* fix dependence errors

* fix dependence errors

* fix include error

* fix all dependence errors [build success]

* remove useless include

* recover useless include

* use phi::ToNCCLDataType

* fix namespace

* update new register code

* fix error in fused_gemm_epilogue_utils

* fix error in FusedAttentionKernel param

* finish fused_attention register code [build success]

* add paddle::optional

* add sig file

* fix build error

* fix a include error

* update CMakeLists

* fix parameter sequence

* add include file

* update #if before include

* fix grammar error

* update codes for DropoutParam

* remove const cast

* trans some fluid api to phi api

* add #if

* update test code

* update test codes

* recover test codes

* trans fused_attention to fluid

* move #endif to end

* move #endif

* delete useless files

* use fused attention utils and recover random seed

* remove fluid include in phi
Parent 6df4a667
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_attention_utils.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
@@ -32,377 +33,21 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/distributed/collective/process_group_nccl.h"
-#include "paddle/fluid/platform/collective_helper.h"
-#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
-#endif
// for phi fused attention
// fluid include will be removed after fused attention grad kernel is merged
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"
#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h"
#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
namespace paddle {
namespace operators {
template <typename T>
static void AllReduce(phi::DenseTensor &tensor, // NOLINT
const int ring_id,
const phi::GPUContext &ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait();
} else {
auto dtype = platform::ToNCCLDataType(
framework::TransToProtoVarType(tensor.dtype()));
int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>();
auto place = ctx.GetPlace();
void *recvbuff = ctx.template Alloc<T>(&tensor, tensor.numel() * sizeof(T));
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."));
#endif
}
template <typename T, typename DeviceContext>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
auto *input_x = ctx.Input<phi::DenseTensor>("X");
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
auto *ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
auto *ln_bias = ctx.Input<phi::DenseTensor>("LnBias");
auto *ln_mean = ctx.Output<phi::DenseTensor>("LnMean");
auto *ln_var = ctx.Output<phi::DenseTensor>("LnVariance");
auto *ln_out = ctx.Output<phi::DenseTensor>("LnOut");
const auto num_heads = ctx.Attr<int>("num_heads");
const auto transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");
// x: qkv's input [batch_size, seq_len, dim_embed]
// if transpose_qkv_wb is False
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
// if transpose_qkv_wb is True
// y: qkv's weight: [dim_embed, 3 * dim_embed]
auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
auto *qkv_out = ctx.Output<phi::DenseTensor>("QKVOut");
auto *qkv_bias_out = ctx.Output<phi::DenseTensor>("QKVBiasOut");
auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
auto *transpose_out_2 = ctx.Output<phi::DenseTensor>("TransposeOut2");
auto *cache_kv = ctx.Input<phi::DenseTensor>("CacheKV");
auto *cache_kv_out = ctx.Output<phi::DenseTensor>("CacheKVOut");
auto *qk_out = ctx.Output<phi::DenseTensor>("QKOut");
auto *qktv_out = ctx.Output<phi::DenseTensor>("QKTVOut");
auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
auto *attn_dropout_mask_out =
ctx.Output<phi::DenseTensor>("AttnDropoutMaskOut");
auto *attn_dropout_out = ctx.Output<phi::DenseTensor>("AttnDropoutOut");
auto *src_mask_out = ctx.Output<phi::DenseTensor>("SrcMaskOut");
auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut");
auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
auto *out_linear_out = ctx.Output<phi::DenseTensor>("OutLinearOut");
auto *ln_scale_2 = ctx.Input<phi::DenseTensor>("Ln2Scale");
auto *ln_bias_2 = ctx.Input<phi::DenseTensor>("Ln2Bias");
auto *dropout_mask_out = ctx.Output<phi::DenseTensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Output<phi::DenseTensor>("BiasDropoutResidualOut");
auto *ln_mean_2 = ctx.Output<phi::DenseTensor>("Ln2Mean");
auto *ln_var_2 = ctx.Output<phi::DenseTensor>("Ln2Variance");
const float ln_epsilon = ctx.Attr<float>("ln_epsilon");
float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
const bool has_attn_dropout = (attn_dropout_rate != 0.0f);
DropoutParam dropout_param2(ctx, 0);
const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
bool is_test_1 = ctx.Attr<bool>("is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
bool is_upscale_in_train_1 =
(dropout_implementation_1 == "upscale_in_train");
auto *seed_1 =
ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
int ring_id = ctx.Attr<int>("ring_id");
// final output.
auto *out = ctx.Output<phi::DenseTensor>("Y");
// get data ptr for qkv part.
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
auto *x_data = input_x->data<T>();
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
auto *qkv_out_data =
dev_ctx.template Alloc<T>(qkv_out, qkv_out->numel() * sizeof(T));
auto *qkv_bias_out_data =
(qkv_bias == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(qkv_bias_out,
qkv_bias_out->numel() * sizeof(T));
// get data ptr for FMHA.
auto *transpose_out_2_data = dev_ctx.template Alloc<T>(
transpose_out_2, transpose_out_2->numel() * sizeof(T));
auto *cache_kv_out_data =
(cache_kv_out == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(cache_kv_out,
cache_kv_out->numel() * sizeof(T));
auto *qk_out_data =
dev_ctx.template Alloc<T>(qk_out, qk_out->numel() * sizeof(T));
auto *qktv_out_data =
dev_ctx.template Alloc<T>(qktv_out, qktv_out->numel() * sizeof(T));
auto *src_mask_out_data =
(src_mask == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(src_mask_out,
src_mask_out->numel() * sizeof(T));
auto *softmax_out_data = dev_ctx.template Alloc<T>(
softmax_out, softmax_out->numel() * sizeof(T));
auto *attn_dropout_mask_out_data =
has_attn_dropout ? dev_ctx.template Alloc<uint8_t>(
attn_dropout_mask_out,
attn_dropout_mask_out->numel() * sizeof(uint8_t))
: nullptr;
auto *attn_dropout_out_data =
has_attn_dropout
? dev_ctx.template Alloc<T>(attn_dropout_out,
attn_dropout_out->numel() * sizeof(T))
: nullptr;
auto *fmha_out_data =
dev_ctx.template Alloc<T>(fmha_out, fmha_out->numel() * sizeof(T));
// get data ptr for out_linear.
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data =
(out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
auto *out_linear_out_data = dev_ctx.template Alloc<T>(
out_linear_out, out_linear_out->numel() * sizeof(T));
// get data ptr for bias+dropout+residual+layernorm
auto *dropout_mask_out_data =
has_dropout
? dev_ctx.template Alloc<uint8_t>(
dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t))
: nullptr;
auto *final_out_data =
dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
int batch_size = input_x_dims[0];
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int num_head;
int dim_head;
int nranks = 1;
// get num_head and dim_head in two different ways
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
num_head = num_heads;
dim_head = dim_embed / (num_head * nranks);
}
int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
auto layer_norm_compute = AttnLayerNorm<T>(
ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
bool compute_bias = true;
if (qkv_bias == nullptr) {
compute_bias = false;
}
// (transA, transB, compute_bias) = (false, true, true)
bool transB = transpose_qkv_wb ? false : true;
auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
false,
transB,
bsz_seq,
output_size,
input_size,
compute_bias);
AttnDropoutParam attn_dropout_param(is_test_1,
dropout_implementation_1,
attn_dropout_rate,
is_upscale_in_train_1,
is_fix_seed_1,
seed_val_1,
seed_1);
auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
batch_size,
max_seq_len,
num_head,
dim_head,
attn_dropout_param);
output_size = hidden_size;
// (transA, transB, compute_bias) = (false, false, false)
// NOTE(Yuang Liu): For general input size == output size, change the
// position won't have effects. For mp, the output size is mp_head * dkey
// which is actually the input size. While the input size is hidden size,
// which is actually the output size. So for out linear, switch the
// input size and output size.
auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
false,
false,
bsz_seq,
input_size,
output_size,
false);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(),
bsz_seq,
dim_embed,
dropout_param2,
ln_epsilon);
if (pre_layer_norm) {
auto *ln_scale_data =
(ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data =
dev_ctx.template Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
auto *ln_var_data =
dev_ctx.template Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
auto *ln_out_data =
dev_ctx.template Alloc<T>(ln_out, ln_out->numel() * sizeof(T));
layer_norm_compute.ComputeForward(x_data,
ln_scale_data,
ln_bias_data,
ln_out_data,
ln_mean_data,
ln_var_data);
qkv_compute.ComputeForward(
qkv_weight, ln_out, qkv_bias, qkv_out, qkv_bias_out);
} else {
qkv_compute.ComputeForward(
qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out);
}
if (transpose_qkv_wb) {
// resize the output for fmha compute
qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
qkv_bias_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
}
if (qkv_bias == nullptr) {
fmha_ref_compute.ComputeForward(*qkv_out,
cache_kv,
src_mask,
transpose_out_2,
cache_kv_out,
qk_out,
src_mask_out,
softmax_out,
attn_dropout_mask_out,
attn_dropout_out,
qktv_out,
fmha_out);
} else {
fmha_ref_compute.ComputeForward(*qkv_bias_out,
cache_kv,
src_mask,
transpose_out_2,
cache_kv_out,
qk_out,
src_mask_out,
softmax_out,
attn_dropout_mask_out,
attn_dropout_out,
qktv_out,
fmha_out);
}
if (transpose_qkv_wb) {
// resize the output back to make the shape compatible with infer shape
qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
}
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute.ComputeForward(
out_linear_weight, fmha_out, nullptr, out_linear_out, nullptr);
// tensor model parallel
AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());
bool add_residual = ctx.Attr<bool>("add_residual");
const T *residual_ptr = add_residual ? x_data : nullptr;
if (pre_layer_norm) {
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx.cuda_device_context(),
out_linear_out_data,
residual_ptr,
out_linear_bias_data,
final_out_data,
dropout_mask_out_data);
} else {
// TODO(Xreki): support post layer_norm case when add_residual is false.
PADDLE_ENFORCE_EQ(add_residual,
true,
platform::errors::InvalidArgument(
"Attribute add_residual is expected to be true "
"when pre_layer_norm is false."));
const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
T *bias_dropout_residual_out_ptr = dev_ctx.template Alloc<T>(
bias_dropout_residual_out,
bias_dropout_residual_out->numel() * sizeof(T));
U *ln_mean_2_ptr =
dev_ctx.template Alloc<U>(ln_mean_2, ln_mean_2->numel() * sizeof(U));
U *ln_var_2_ptr =
dev_ctx.template Alloc<U>(ln_var_2, ln_var_2->numel() * sizeof(U));
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(),
out_linear_out_data,
residual_ptr,
out_linear_bias_data,
ln_scale_2_ptr,
ln_bias_2_ptr,
bias_dropout_residual_out_ptr,
dropout_mask_out_data,
final_out_data,
ln_mean_2_ptr,
ln_var_2_ptr);
}
}
};
template <typename T, typename DeviceContext>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
@@ -790,7 +435,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
ln_out, qkv_weight, d_qkv_out, d_ln_out, d_qkv_weight, d_qkv_bias);
}
// tensor model parallel
-AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
phi::fusion::AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
layer_norm_compute.ComputeBackward(x_data,
d_ln_out_data,
ln_scale_data,
@@ -808,7 +453,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
input_x, qkv_weight, d_qkv_out, d_x, d_qkv_weight, d_qkv_bias);
}
// tensor model parallel
-AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
phi::fusion::AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
}
if (add_residual) {
@@ -824,20 +469,364 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedAttentionKernel(const Context &dev_ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &ln_scale,
const paddle::optional<DenseTensor> &ln_bias,
const DenseTensor &qkv_weight,
const paddle::optional<DenseTensor> &qkv_bias,
const paddle::optional<DenseTensor> &cache_kv,
const paddle::optional<DenseTensor> &src_mask,
const DenseTensor &out_linear_weight,
const paddle::optional<DenseTensor> &out_linear_bias,
const paddle::optional<DenseTensor> &ln_scale_2,
const paddle::optional<DenseTensor> &ln_bias_2,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string &attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string &dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
DenseTensor *ln_mean,
DenseTensor *ln_var,
DenseTensor *ln_out,
DenseTensor *qkv_out,
DenseTensor *qkv_bias_out,
DenseTensor *transpose_out_2,
DenseTensor *qk_out,
DenseTensor *qktv_out,
DenseTensor *softmax_out,
DenseTensor *attn_dropout_mask_out,
DenseTensor *attn_dropout_out,
DenseTensor *src_mask_out,
DenseTensor *fmha_out,
DenseTensor *out_linear_out,
DenseTensor *dropout_mask_out,
DenseTensor *ln_mean_2,
DenseTensor *ln_var_2,
DenseTensor *bias_dropout_residual_out,
DenseTensor *cache_kv_out,
DenseTensor *out) {
using U = phi::funcs::LayerNormParamType<T>;
-PD_REGISTER_STRUCT_KERNEL(fused_attention,
-GPU,
-ALL_LAYOUT,
-ops::FusedAttentionOpKernel,
-float,
-double,
-plat::float16) {}
-PD_REGISTER_STRUCT_KERNEL(fused_attention_grad,
// x: qkv's input [batch_size, seq_len, dim_embed]
// if transpose_qkv_wb is False
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
// if transpose_qkv_wb is True
// y: qkv's weight: [dim_embed, 3 * dim_embed]
auto *x_p = &x;
auto *ln_scale_p = ln_scale.get_ptr();
auto *ln_bias_p = ln_bias.get_ptr();
auto *qkv_weight_p = &qkv_weight;
auto *qkv_bias_p = qkv_bias.get_ptr();
auto *cache_kv_p = cache_kv.get_ptr();
auto *src_mask_p = src_mask.get_ptr();
auto *out_linear_weight_p = &out_linear_weight;
auto *out_linear_bias_p = out_linear_bias.get_ptr();
auto *ln_scale_2_p = ln_scale_2.get_ptr();
auto *ln_bias_2_p = ln_bias_2.get_ptr();
const bool has_attn_dropout = (attn_dropout_rate != 0.0f);
const bool is_upscale_in_train =
(dropout_implementation == "upscale_in_train");
phi::fusion::DropoutParam dropout_param2(dropout_fix_seed,
0,
is_test,
is_upscale_in_train,
dropout_rate,
nullptr,
dropout_seed);
const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
bool is_upscale_in_train_1 =
(attn_dropout_implementation == "upscale_in_train");
phi::DenseTensor *seed_1 = nullptr;
// get data ptr for qkv part.
const auto input_x_dims = x_p->dims();
const auto qkv_w_dims = qkv_weight_p->dims();
auto *x_data = x_p->data<T>();
auto *qkv_weight_data = qkv_weight_p->data<T>();
auto *qkv_bias_data =
(qkv_bias_p == nullptr) ? nullptr : qkv_bias_p->data<T>();
auto *qkv_out_data =
dev_ctx.template Alloc<T>(qkv_out, qkv_out->numel() * sizeof(T));
auto *qkv_bias_out_data =
(qkv_bias_p == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(qkv_bias_out,
qkv_bias_out->numel() * sizeof(T));
// get data ptr for FMHA.
auto *transpose_out_2_data = dev_ctx.template Alloc<T>(
transpose_out_2, transpose_out_2->numel() * sizeof(T));
auto *cache_kv_out_data =
(cache_kv_out == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(cache_kv_out,
cache_kv_out->numel() * sizeof(T));
auto *qk_out_data =
dev_ctx.template Alloc<T>(qk_out, qk_out->numel() * sizeof(T));
auto *qktv_out_data =
dev_ctx.template Alloc<T>(qktv_out, qktv_out->numel() * sizeof(T));
auto *src_mask_out_data =
(src_mask_p == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(src_mask_out,
src_mask_out->numel() * sizeof(T));
auto *softmax_out_data =
dev_ctx.template Alloc<T>(softmax_out, softmax_out->numel() * sizeof(T));
auto *attn_dropout_mask_out_data =
has_attn_dropout ? dev_ctx.template Alloc<uint8_t>(
attn_dropout_mask_out,
attn_dropout_mask_out->numel() * sizeof(uint8_t))
: nullptr;
auto *attn_dropout_out_data =
has_attn_dropout
? dev_ctx.template Alloc<T>(attn_dropout_out,
attn_dropout_out->numel() * sizeof(T))
: nullptr;
auto *fmha_out_data =
dev_ctx.template Alloc<T>(fmha_out, fmha_out->numel() * sizeof(T));
// get data ptr for out_linear.
auto *out_linear_weight_data = out_linear_weight_p->data<T>();
auto *out_linear_bias_data =
(out_linear_bias_p == nullptr) ? nullptr : out_linear_bias_p->data<T>();
auto *out_linear_out_data = dev_ctx.template Alloc<T>(
out_linear_out, out_linear_out->numel() * sizeof(T));
// get data ptr for bias+dropout+residual+layernorm
auto *dropout_mask_out_data =
has_dropout
? dev_ctx.template Alloc<uint8_t>(
dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t))
: nullptr;
auto *final_out_data =
dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
int batch_size = input_x_dims[0];
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int num_head;
int dim_head;
int nranks = 1;
// get num_head and dim_head in two different ways
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
num_head = num_heads;
dim_head = dim_embed / (num_head * nranks);
}
int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
auto layer_norm_compute =
phi::fusion::AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
bool compute_bias = true;
if (qkv_bias_p == nullptr) {
compute_bias = false;
}
// (transA, transB, compute_bias) = (false, true, true)
bool transB = transpose_qkv_wb ? false : true;
auto qkv_compute = phi::fusion::AttnMatMul<T>(
dev_ctx, false, transB, bsz_seq, output_size, input_size, compute_bias);
phi::fusion::AttnDropoutParam attn_dropout_param(is_test,
attn_dropout_implementation,
attn_dropout_rate,
is_upscale_in_train_1,
attn_dropout_fix_seed,
attn_dropout_seed,
seed_1);
auto fmha_ref_compute = phi::fusion::FMHARef<T>(
dev_ctx, batch_size, max_seq_len, num_head, dim_head, attn_dropout_param);
output_size = hidden_size;
// (transA, transB, compute_bias) = (false, false, false)
// NOTE(Yuang Liu): For general input size == output size, change the
// position won't have effects. For mp, the output size is mp_head * dkey
// which is actually the input size. While the input size is hidden size,
// which is actually the output size. So for out linear, switch the
// input size and output size.
auto out_linear_compute = phi::fusion::AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, input_size, output_size, false);
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
fused_dropout_layernorm_helper(
dev_ctx, bsz_seq, dim_embed, dropout_param2, ln_epsilon);
if (pre_layer_norm) {
auto *ln_scale_data =
(ln_scale_p == nullptr ? nullptr : ln_scale_p->data<U>());
auto *ln_bias_data =
(ln_bias_p == nullptr ? nullptr : ln_bias_p->data<U>());
auto *ln_mean_data =
dev_ctx.template Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
auto *ln_var_data =
dev_ctx.template Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
auto *ln_out_data =
dev_ctx.template Alloc<T>(ln_out, ln_out->numel() * sizeof(T));
layer_norm_compute.ComputeForward(x_data,
ln_scale_data,
ln_bias_data,
ln_out_data,
ln_mean_data,
ln_var_data);
qkv_compute.ComputeForward(
qkv_weight_p, ln_out, qkv_bias_p, qkv_out, qkv_bias_out);
} else {
qkv_compute.ComputeForward(
qkv_weight_p, x_p, qkv_bias_p, qkv_out, qkv_bias_out);
}
if (transpose_qkv_wb) {
// resize the output for fmha compute
qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
qkv_bias_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
}
if (qkv_bias_p == nullptr) {
fmha_ref_compute.ComputeForward(*qkv_out,
cache_kv_p,
src_mask_p,
transpose_out_2,
cache_kv_out,
qk_out,
src_mask_out,
softmax_out,
attn_dropout_mask_out,
attn_dropout_out,
qktv_out,
fmha_out);
} else {
fmha_ref_compute.ComputeForward(*qkv_bias_out,
cache_kv_p,
src_mask_p,
transpose_out_2,
cache_kv_out,
qk_out,
src_mask_out,
softmax_out,
attn_dropout_mask_out,
attn_dropout_out,
qktv_out,
fmha_out);
}
if (transpose_qkv_wb) {
// resize the output back to make the shape compatible with infer shape
qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
}
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute.ComputeForward(
out_linear_weight_p, fmha_out, nullptr, out_linear_out, nullptr);
// tensor model parallel
phi::fusion::AllReduce<T>(*out_linear_out, ring_id, dev_ctx);
const T *residual_ptr = add_residual ? x_data : nullptr;
if (pre_layer_norm) {
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias(dev_ctx,
out_linear_out_data,
residual_ptr,
out_linear_bias_data,
final_out_data,
dropout_mask_out_data);
} else {
// TODO(Xreki): support post layer_norm case when add_residual is false.
PADDLE_ENFORCE_EQ(
add_residual,
true,
errors::InvalidArgument("Attribute add_residual is expected to be true "
"when pre_layer_norm is false."));
const U *ln_scale_2_ptr = ln_scale_2_p ? ln_scale_2_p->data<U>() : nullptr;
const U *ln_bias_2_ptr = ln_bias_2_p ? ln_bias_2_p->data<U>() : nullptr;
T *bias_dropout_residual_out_ptr = dev_ctx.template Alloc<T>(
bias_dropout_residual_out,
bias_dropout_residual_out->numel() * sizeof(T));
U *ln_mean_2_ptr =
dev_ctx.template Alloc<U>(ln_mean_2, ln_mean_2->numel() * sizeof(U));
U *ln_var_2_ptr =
dev_ctx.template Alloc<U>(ln_var_2, ln_var_2->numel() * sizeof(U));
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx,
out_linear_out_data,
residual_ptr,
out_linear_bias_data,
ln_scale_2_ptr,
ln_bias_2_ptr,
bias_dropout_residual_out_ptr,
dropout_mask_out_data,
final_out_data,
ln_mean_2_ptr,
ln_var_2_ptr);
}
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_attention,
GPU,
ALL_LAYOUT,
-ops::FusedAttentionGradKernel,
-float,
phi::fusion::FusedAttentionKernel,
phi::dtype::float16,
double,
-plat::float16) {}
float) {
phi::DataType data_type;
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::FLOAT32) {
data_type = phi::DataType::FLOAT32;
} else {
data_type = phi::DataType::FLOAT64;
}
kernel->OutputAt(0).SetDataType(data_type);
kernel->OutputAt(1).SetDataType(data_type);
kernel->OutputAt(3).SetDataType(data_type);
kernel->OutputAt(4).SetDataType(data_type);
kernel->OutputAt(15).SetDataType(data_type);
kernel->OutputAt(16).SetDataType(data_type);
}
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
ops::FusedAttentionGradKernel<float>,
ops::FusedAttentionGradKernel<double>,
ops::FusedAttentionGradKernel<plat::float16>);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#include "paddle/phi/core/errors.h"
namespace phi {
namespace fusion {
template <typename T>
static void AllReduce(phi::DenseTensor &tensor, // NOLINT
const int ring_id,
const phi::GPUContext &dev_ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
auto pg_nccl = static_cast<paddle::distributed::ProcessGroupNCCL *>(pg);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = paddle::distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait();
} else {
auto dtype = phi::ToNCCLDataType(tensor.dtype());
int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>();
auto place = dev_ctx.GetPlace();
void *recvbuff =
dev_ctx.template Alloc<T>(&tensor, tensor.numel() * sizeof(T));
auto comm =
paddle::platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = dev_ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."));
#endif
}
} // namespace fusion
} // namespace phi
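A minimal usage sketch of the AllReduce helper above, as the phi forward kernel calls it after the out-linear GEMM. The wrapper name ReducePartialOutput and the tensor name partial_out are illustrative only; the include path follows the one added to fused_attention_op.cu above.

#include "paddle/fluid/operators/fused/fused_attention_utils.h"

// Sum the per-rank partial results of the out-linear projection across the
// tensor-model-parallel group selected by ring_id. A ring_id of -1 disables
// the reduction, matching the early return inside AllReduce.
template <typename T>
void ReducePartialOutput(const phi::GPUContext &dev_ctx,
                         phi::DenseTensor *partial_out,  // e.g. out_linear_out
                         int ring_id) {
  phi::fusion::AllReduce<T>(*partial_out, ring_id, dev_ctx);
}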
@@ -19,6 +19,8 @@ limitations under the License. */
#include <cuda_runtime_api.h> // NOLINT
#include "cuda.h" // NOLINT
#include "paddle/phi/backends/dynload/cublasLt.h"
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
......
@@ -24,6 +24,7 @@ namespace cub = hipcub;
#include <iostream>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/memory_utils.h"
......
@@ -1137,14 +1137,32 @@ void ReduceKernel(const KPDevice& dev_ctx,
is_mean);
}
template <typename Tx,
typename Ty,
template <typename>
class ReduceOp,
typename TransformOp>
void TensorReduceImpl(const phi::GPUContext& dev_ctx,
const phi::DenseTensor& x,
phi::DenseTensor* y,
const TransformOp& transform,
const std::vector<int>& origin_reduce_dims,
gpuStream_t stream,
bool is_mean = false) {
dev_ctx.template Alloc<Ty>(y);
ReduceKernel<Tx, Ty, ReduceOp, TransformOp>(
static_cast<const phi::GPUContext&>(dev_ctx),
x,
y,
transform,
origin_reduce_dims,
is_mean);
}
#endif
-template <typename DeviceContext,
-typename T,
-size_t D,
-size_t R_D,
-typename Functor>
-void ReduceFunctor(const DeviceContext& context,
template <typename Context, typename T, size_t D, size_t R_D, typename Functor>
void ReduceFunctor(const Context& context,
const phi::DenseTensor& input,
phi::DenseTensor* output,
const std::vector<int64_t>& dims,
@@ -1183,7 +1201,7 @@ void ReduceFunctor(const DeviceContext& context,
#define HANDLE_REDUCE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
-ReduceFunctor<DeviceContext, OutT, NDIM, RDIM, Functor>( \
ReduceFunctor<Context, OutT, NDIM, RDIM, Functor>( \
dev_ctx, input, output, dims, keep_dim); \
}
//////////////// HandleLargeDim
@@ -1220,8 +1238,8 @@ inline void GetShuffledDim(const DDim& src_dims,
}
}
-template <typename DeviceContext, typename OutT>
-void GetShuffledInput(const DeviceContext& dev_ctx,
template <typename Context, typename OutT>
void GetShuffledInput(const Context& dev_ctx,
const phi::DenseTensor& input,
phi::DenseTensor* shuffled_input,
const std::vector<int64_t>& dims) {
@@ -1232,19 +1250,19 @@ void GetShuffledInput(const DeviceContext& dev_ctx,
shuffled_input->Resize(shuffled_dims);
dev_ctx.template Alloc<OutT>(shuffled_input);
-phi::funcs::TransposeNormal<DeviceContext, OutT> trans;
phi::funcs::TransposeNormal<Context, OutT> trans;
trans(dev_ctx, input, shuffled_input, perm_axis);
}
-template <typename DeviceContext, typename OutT, typename Functor>
-void HandleLargeDim(const DeviceContext& dev_ctx,
template <typename Context, typename OutT, typename Functor>
void HandleLargeDim(const Context& dev_ctx,
const phi::DenseTensor& input,
phi::DenseTensor* output,
const std::vector<int64_t>& dims,
bool keep_dim) {
// shuffle the reduced dim to the end
phi::DenseTensor shuffled_input;
-GetShuffledInput<DeviceContext, OutT>(dev_ctx, input, &shuffled_input, dims);
GetShuffledInput<Context, OutT>(dev_ctx, input, &shuffled_input, dims);
// transpose to 2D tensor whose shape is {unreduced, reduced}.
const int64_t unreduced = output->numel();
@@ -1266,15 +1284,15 @@ void HandleLargeDim(const DeviceContext& dev_ctx,
DDim output_dim = output->dims();
output->ResizeAndAllocate({unreduced});
-ReduceFunctor<DeviceContext, OutT, 2, 1, Functor>(
ReduceFunctor<Context, OutT, 2, 1, Functor>(
dev_ctx, shuffled_input, output, {1}, keep_dim);
output->ResizeAndAllocate(output_dim);
}
////////////// ReduceKernel
-template <typename DeviceContext, typename T, typename OutT, typename Functor>
-void ReduceKernelImpl(const DeviceContext& dev_ctx,
template <typename Context, typename T, typename OutT, typename Functor>
void ReduceKernelImpl(const Context& dev_ctx,
const phi::DenseTensor& input,
phi::DenseTensor* output,
const std::vector<int64_t>& dims,
@@ -1295,7 +1313,7 @@ void ReduceKernelImpl(const DeviceContext& dev_ctx,
int ndim = input.dims().size();
int rdim = dims.size();
if (ndim > 6) {
-HandleLargeDim<DeviceContext, OutT, Functor>(
HandleLargeDim<Context, OutT, Functor>(
dev_ctx, input, output, dims, keep_dim);
} else {
......
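The TensorReduceImpl wrapper added above is what the phi AttnMatMul helper further below uses for the bias gradient. A minimal sketch of that call pattern follows; the function name ReduceBiasGrad is illustrative, and reduce_dims would be {0, 1} to collapse a [bs, seq_len, dim] gradient into [dim].

#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

// Sum d_output over its leading dimensions into d_bias on the kernel's stream.
// d_bias must already be resized to the trailing dimensions of d_output.
template <typename T>
void ReduceBiasGrad(const phi::GPUContext &dev_ctx,
                    const phi::DenseTensor &d_output,
                    phi::DenseTensor *d_bias,
                    const std::vector<int> &reduce_dims) {
  phi::funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
      dev_ctx, d_output, d_bias, kps::IdentityFunctor<T>(), reduce_dims,
      dev_ctx.stream());
}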
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
namespace phi {
namespace fusion {
// NOTE: T must be the same as OutType in ComputeBackward
template <typename T, typename InType = T, typename OutType = T>
class AttnLayerNorm {
public:
AttnLayerNorm(const phi::GPUContext& dev_ctx,
float epsilon,
int64_t batch_size,
int64_t feature_size)
: dev_ctx_(dev_ctx),
epsilon_(epsilon),
batch_size_(batch_size),
feature_size_(feature_size) {}
~AttnLayerNorm() {}
void ComputeForward(const InType* x_data,
const phi::funcs::LayerNormParamType<T>* scale_data,
const phi::funcs::LayerNormParamType<T>* bias_data,
OutType* y_data,
phi::funcs::LayerNormParamType<T>* mean_data,
phi::funcs::LayerNormParamType<T>* var_data,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto stream = dev_ctx_.stream();
switch (phi::funcs::GetDesiredBlockDim(feature_size_)) {
FIXED_BLOCK_DIM_CASE(
phi::funcs::LayerNormForward<T,
phi::funcs::LayerNormParamType<T>,
kBlockDim,
false,
InType,
OutType>
<<<batch_size_, kBlockDim, 0, stream>>>(x_data,
scale_data,
bias_data,
y_data,
mean_data,
var_data,
epsilon_,
feature_size_,
dequant_out_scale_data,
quant_out_scale_offset,
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound));
default:
PADDLE_THROW(
phi::errors::InvalidArgument("Feature_size must be larger than 1"));
break;
}
}
void ComputeBackward(const T* x_data,
const T* d_y_data,
const phi::funcs::LayerNormParamType<T>* scale_data,
const phi::funcs::LayerNormParamType<T>* mean_data,
const phi::funcs::LayerNormParamType<T>* var_data,
T* d_x_data,
phi::funcs::LayerNormParamType<T>* d_scale_data,
phi::funcs::LayerNormParamType<T>* d_bias_data) {
phi::funcs::LayerNormBackward<T, phi::funcs::LayerNormParamType<T>>(
x_data,
d_y_data,
scale_data,
mean_data,
var_data,
d_x_data,
d_scale_data,
d_bias_data,
epsilon_,
batch_size_,
feature_size_,
dev_ctx_);
}
private:
const phi::GPUContext& dev_ctx_;
int64_t batch_size_;
int64_t feature_size_;
float epsilon_;
};
} // namespace fusion
} // namespace phi
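A minimal usage sketch of AttnLayerNorm as the forward kernel drives it on the pre_layer_norm path; the wrapper name and parameter names are illustrative, assuming bsz_seq = batch_size * max_seq_len and row-major [bsz_seq, dim_embed] activations.

#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"

// y = LayerNorm(x) over the last dimension; mean/var are written per row so
// the backward pass can reuse them, as FusedAttentionKernel does.
template <typename T>
void PreLayerNorm(const phi::GPUContext &dev_ctx,
                  const T *x_data,                                 // [bsz_seq, dim_embed]
                  const phi::funcs::LayerNormParamType<T> *scale,  // [dim_embed] or nullptr
                  const phi::funcs::LayerNormParamType<T> *bias,   // [dim_embed] or nullptr
                  T *ln_out,                                       // [bsz_seq, dim_embed]
                  phi::funcs::LayerNormParamType<T> *mean,         // [bsz_seq]
                  phi::funcs::LayerNormParamType<T> *var,          // [bsz_seq]
                  float epsilon,
                  int64_t bsz_seq,
                  int64_t dim_embed) {
  phi::fusion::AttnLayerNorm<T> layer_norm(dev_ctx, epsilon, bsz_seq, dim_embed);
  layer_norm.ComputeForward(x_data, scale, bias, ln_out, mean, var);
}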
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_CUDA)
#include "paddle/phi/backends/dynload/cublasLt.h"
#endif
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace phi {
namespace fusion {
// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template <typename T>
class AttnMatMul {
public:
// (m, n, k) = bsz_seq, output_size, input_size
AttnMatMul(const phi::GPUContext& dev_ctx,
bool transA,
bool transB,
int bsz_seq,
int output_size,
int input_size,
bool compute_bias)
: dev_ctx_(dev_ctx),
transA_(transA),
transB_(transB),
bsz_seq_(bsz_seq),
output_size_(output_size),
input_size_(input_size),
compute_bias_(compute_bias) {}
void ComputeForward(const phi::DenseTensor* weight,
const phi::DenseTensor* input,
const phi::DenseTensor* bias,
phi::DenseTensor* output,
phi::DenseTensor* bias_out,
bool fused = false) {
VLOG(6) << "input.shape={" << input->dims() << "}, weight.shape={"
<< weight->dims() << "}, output.shape={" << output->dims()
<< "}, batch_size=" << bsz_seq_ << ", output_size=" << output_size_
<< ", input_size=" << input_size_ << ", transA=" << transA_
<< ", transB=" << transB_ << ", compute_bias=" << compute_bias_
<< ", fused=" << fused;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (compute_bias_ && fused) {
PADDLE_ENFORCE_EQ(
!output || output == bias_out,
true,
phi::errors::InvalidArgument(
"The output (= input * weight) is expected to be nullptr or the "
"same as bias_out when fused is true."));
phi::funcs::ComputeFusedGemmEpilogueForward<T>(dev_ctx_,
input,
weight,
bias,
bsz_seq_, // M
output_size_, // N
input_size_, // K
transA_,
transB_,
"none",
bias_out,
nullptr);
return;
}
#endif
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// here: (transa, transb): nt, input * weight.
CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans;
CBLAS_TRANSPOSE transB = transB_ ? CblasTrans : CblasNoTrans;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
// (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
blas.GEMM(transA,
transB,
bsz_seq_,
output_size_,
input_size_,
alpha,
input->data<T>(),
weight->data<T>(),
beta,
output->data<T>());
if (compute_bias_) {
// bias_out = output + bias
std::vector<const phi::DenseTensor*> ins = {output, bias};
std::vector<phi::DenseTensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
}
}
void ComputeBackward(const phi::DenseTensor* input,
const phi::DenseTensor* weight,
const phi::DenseTensor* d_output,
phi::DenseTensor* d_input,
phi::DenseTensor* d_weight,
phi::DenseTensor* d_bias,
bool use_addto = false,
bool fused = false) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (compute_bias_ && fused) {
phi::funcs::ComputeFusedGemmEpilogueBackward<T>(dev_ctx_,
d_output,
input,
weight,
nullptr,
bsz_seq_, // M
output_size_, // N
input_size_, // K
transA_,
transB_,
"none",
d_input,
d_weight,
d_bias,
use_addto);
return;
}
#endif
T alpha = static_cast<T>(1.0);
T beta_dA = use_addto ? static_cast<T>(1.0) : static_cast<T>(0.0);
T beta_dB = static_cast<T>(0.0);
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
if (!transA_) {
// forward: gemm-nt
if (transB_) {
// backward: gemm-tn, dB = (dC)^T * A
if (d_weight) {
int dB_m = output_size_;
int dB_n = input_size_;
int dB_k = bsz_seq_;
T* dB_output_ptr = d_weight->data<T>();
blas.GEMM(CblasTrans,
CblasNoTrans,
dB_m,
dB_n,
dB_k,
alpha,
d_output->data<T>(),
input->data<T>(),
beta_dB,
dB_output_ptr);
}
// backward: gemm-nn, dA = dC * B
if (d_input) {
int dA_m = bsz_seq_;
int dA_n = input_size_;
int dA_k = output_size_;
T* dA_output_ptr = d_input->data<T>();
blas.GEMM(CblasNoTrans,
CblasNoTrans,
dA_m,
dA_n,
dA_k,
alpha,
d_output->data<T>(),
weight->data<T>(),
beta_dA,
dA_output_ptr);
}
} else { // fw: gemm-nn
// backward: gemm-tn, dB = A^T * dC
if (d_weight) {
int dB_m = input_size_;
int dB_n = output_size_;
int dB_k = bsz_seq_;
T* dB_output_ptr = d_weight->data<T>();
blas.GEMM(CblasTrans,
CblasNoTrans,
dB_m,
dB_n,
dB_k,
alpha,
input->data<T>(),
d_output->data<T>(),
beta_dB,
dB_output_ptr);
}
// backward: gemm-nt, dA = dC * B^T
if (d_input) {
int dA_m = bsz_seq_;
int dA_n = input_size_;
int dA_k = output_size_;
T* dA_output_ptr = d_input->data<T>();
blas.GEMM(CblasNoTrans,
CblasTrans,
dA_m,
dA_n,
dA_k,
alpha,
d_output->data<T>(),
weight->data<T>(),
beta_dA,
dA_output_ptr);
}
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"AttnMatMul wrapper do not support (transA=T, transB=T/N)"
"parameters."));
}
if (compute_bias_ && d_bias) {
// reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2} or {0,1,2,3}
// -> {3} or {0,1,2,3,4} -> {3,4}
const auto input_dims = d_output->dims();
const auto output_dims = d_bias->dims();
bool support_case_1 =
(input_dims.size() == 5 && output_dims.size() == 3 &&
(input_dims[2] == output_dims[0]) &&
(input_dims[3] == output_dims[1]) &&
(input_dims[4] == output_dims[2]));
bool support_case_2 =
(input_dims.size() == 3 && output_dims.size() == 1 &&
(input_dims[2] == output_dims[0]));
bool support_case_3 =
(input_dims.size() == 4 && output_dims.size() == 1 &&
input_dims[3] == output_dims[0]);
bool support_case_4 =
(input_dims.size() == 5 && output_dims.size() == 2 &&
input_dims[3] == output_dims[0] && input_dims[4] == output_dims[1]);
gpuStream_t stream = dev_ctx_.stream();
if (support_case_1 || support_case_2) {
phi::funcs::
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx_,
*d_output,
d_bias,
kps::IdentityFunctor<T>(),
{0, 1},
stream);
} else if (support_case_3 || support_case_4) {
phi::funcs::
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx_,
*d_output,
d_bias,
kps::IdentityFunctor<T>(),
{0, 1, 2},
stream);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Only support reduce when the input dims are [0,1,2,3,4] and "
"output is [2,3,4]"
"or input is [0,1,2] and output is [2]."));
}
}
}
private:
const phi::GPUContext& dev_ctx_;
bool transA_;
bool transB_;
int bsz_seq_;
int output_size_;
int input_size_;
int compute_bias_;
};
} // namespace fusion
} // namespace phi
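A minimal sketch of the QKV projection as FusedAttentionKernel issues it through this helper on the transpose_qkv_wb == false path; the wrapper name is illustrative, and the tensors are assumed to be pre-allocated DenseTensors with the shapes noted in the comments.

#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h"

// qkv_out = x * qkv_weight^T and, when a bias is given, qkv_bias_out =
// qkv_out + qkv_bias, with (m, n, k) = (bsz_seq, 3 * num_head * dim_head, dim_embed).
template <typename T>
void QkvProjection(const phi::GPUContext &dev_ctx,
                   const phi::DenseTensor *x,           // [bsz_seq, dim_embed]
                   const phi::DenseTensor *qkv_weight,  // [3, num_head, dim_head, dim_embed]
                   const phi::DenseTensor *qkv_bias,    // [3, num_head, dim_head] or nullptr
                   phi::DenseTensor *qkv_out,
                   phi::DenseTensor *qkv_bias_out,
                   int bsz_seq,
                   int output_size,   // 3 * num_head * dim_head
                   int input_size) {  // dim_embed
  // (transA, transB, compute_bias) = (false, true, qkv_bias != nullptr),
  // matching the kernel's qkv_compute construction.
  phi::fusion::AttnMatMul<T> qkv_compute(dev_ctx,
                                         false,
                                         true,
                                         bsz_seq,
                                         output_size,
                                         input_size,
                                         qkv_bias != nullptr);
  qkv_compute.ComputeForward(qkv_weight, x, qkv_bias, qkv_out, qkv_bias_out);
}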
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace phi {
namespace fusion {
class AttnDropoutParam {
public:
AttnDropoutParam() {
is_test_ = false;
dropout_implementation_ = "downgrade_in_infer";
dropout_prob_ = 0.5;
is_upscale_in_train_ = false;
is_fix_seed_ = false;
seed_val_ = 0;
seed_ = nullptr;
}
AttnDropoutParam(bool is_test,
const std::string dropout_implementation,
float dropout_prob,
bool is_upscale_in_train,
bool is_fix_seed,
int seed_val,
const phi::DenseTensor* seed) {
is_test_ = is_test;
dropout_implementation_ = dropout_implementation;
dropout_prob_ = dropout_prob;
is_upscale_in_train_ = is_upscale_in_train;
is_fix_seed_ = is_fix_seed;
seed_val_ = seed_val;
seed_ = seed;
}
bool is_test_;
std::string dropout_implementation_;
float dropout_prob_;
bool is_upscale_in_train_;
bool is_fix_seed_;
int seed_val_;
const phi::DenseTensor* seed_;
};
template <typename T, int VecSize>
__global__ void TransposeRemovingPadding(const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int elem_cnt,
const int* padding_offset) {
// transpose and remove padding
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
const int dim_embed = num_head * head_dim;
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec;
for (int32_t linear_index = idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / dim_embed;
const int ori_token_idx =
token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
const int ori_batch_id = ori_token_idx / seq_len;
const int ori_seq_id = ori_token_idx % seq_len;
const int ori_head_id = (linear_index % dim_embed) / head_dim;
const int ori_head_lane = (linear_index % dim_embed) % head_dim;
const int ori_idx = ori_batch_id * num_head * seq_len * head_dim +
ori_head_id * seq_len * head_dim +
ori_seq_id * head_dim + ori_head_lane;
phi::Load<T, VecSize>(&input_data[ori_idx], &src_vec);
phi::Store<T, VecSize>(src_vec, &output_data[linear_index]);
}
}
template <typename T>
void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx,
const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int* padding_offset) {
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
constexpr int VEC_16B = 16;
const int elem_cnt = token_num * num_head * head_dim;
constexpr int PackSize = VEC_16B / sizeof(T);
PADDLE_ENFORCE_EQ(
head_dim % PackSize,
0,
errors::PreconditionNotMet(
"dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize));
const int32_t pack_num = elem_cnt / PackSize;
const int32_t block_size = 128;
int32_t grid_size = (pack_num + block_size - 1) / block_size;
TransposeRemovingPadding<T, PackSize>
<<<grid_size, block_size, 0, dev_ctx.stream()>>>(input_data,
output_data,
batch_size,
num_head,
seq_len,
head_dim,
token_num,
elem_cnt,
padding_offset);
}
template <typename T>
class FMHARef {
public:
FMHARef(const phi::GPUContext& dev_ctx,
int64_t batch_size,
int64_t seq_len,
int64_t num_head,
int64_t head_dim,
AttnDropoutParam param)
: dev_ctx_(dev_ctx),
batch_size_(batch_size),
seq_len_(seq_len),
num_head_(num_head),
head_dim_(head_dim),
dropout_param_(param) {}
~FMHARef() {}
void ComputeForward(const phi::DenseTensor& qkv_input_tensor,
const phi::DenseTensor* cache_kv_tensor,
const phi::DenseTensor* src_mask_tensor,
phi::DenseTensor* transpose_2_out_tensor,
phi::DenseTensor* cache_kv_out_tensor,
phi::DenseTensor* qk_out_tensor,
phi::DenseTensor* src_mask_out_tensor,
phi::DenseTensor* softmax_out_tensor,
phi::DenseTensor* dropout_mask_out_tensor,
phi::DenseTensor* dropout_out_tensor,
phi::DenseTensor* qktv_out_tensor,
phi::DenseTensor* fmha_out_tensor) {
// input shape: [bs, seq_len, 3, num_head, head_dim]
// transpose with perm [2, 0, 3, 1, 4],
// output_shape: [3, bs, num_head, seq_len, head_dim]
std::vector<int> perm_1 = {2, 0, 3, 1, 4};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor);
T* qkv_data = transpose_2_out_tensor->data<T>();
T* qk_out_data = qk_out_tensor->data<T>();
T* qktv_out_data = qktv_out_tensor->data<T>();
T* softmax_out_data = softmax_out_tensor->data<T>();
T* fmha_out_data = fmha_out_tensor->data<T>();
auto out_seq_len = seq_len_;
if (cache_kv_tensor) {
// kv [2, bs, num_head, seq_len, head_dim]
auto kv_tensor = transpose_2_out_tensor->Slice(1, 3);
phi::funcs::ConcatFunctor<phi::GPUContext, T> concat;
// out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor);
out_seq_len = cache_kv_out_tensor->dims()[3];
}
int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
T* q_ptr = qkv_data;
T* k_ptr = nullptr;
T* v_ptr = nullptr;
if (cache_kv_tensor) {
int64_t k_size = cache_kv_out_tensor->numel() / 2;
k_ptr = cache_kv_out_tensor->data<T>();
v_ptr = k_ptr + k_size;
} else {
int64_t k_size = q_size;
k_ptr = q_ptr + q_size;
v_ptr = k_ptr + k_size;
}
{
// NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for
// float16 calculation, INF may appear in QK^T if we do not scale before.
float alpha = 1.0 / sqrt(head_dim_);
auto q_tensor = transpose_2_out_tensor->Slice(0, 1);
auto functor = phi::funcs::ScaleFunctor<T>(alpha);
std::vector<const phi::DenseTensor*> ins = {&q_tensor};
std::vector<phi::DenseTensor*> outs = {&q_tensor};
phi::funcs::ElementwiseKernel<T>(dev_ctx_, ins, &outs, functor);
}
// q*k^t, batched_gemm
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasTrans;
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
int gemm_batch_size = batch_size_ * num_head_;
int gemm_m = seq_len_;
int gemm_n = out_seq_len;
int gemm_k = head_dim_;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
int64_t stride_a = gemm_m * gemm_k;
int64_t stride_b = gemm_k * gemm_n;
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
q_ptr,
k_ptr,
beta,
qk_out_data,
gemm_batch_size,
stride_a,
stride_b);
int softmax_axis = -1;
if (src_mask_tensor != nullptr) {
if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) {
phi::fusion::LaunchFusedSoftmaxMaskKernel<T>(qk_out_data,
src_mask_tensor->data<T>(),
softmax_out_data,
batch_size_,
num_head_,
seq_len_,
dev_ctx_.stream());
} else {
std::vector<const phi::DenseTensor*> ins;
std::vector<phi::DenseTensor*> outs;
ins.emplace_back(qk_out_tensor);
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_,
ins,
&outs,
elewise_add_axis,
phi::funcs::AddFunctor<T>());
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
}
} else {
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor);
}
transB = CblasNoTrans;
gemm_m = seq_len_;
gemm_n = head_dim_;
gemm_k = out_seq_len;
alpha = static_cast<T>(1.0);
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
if (dropout_param_.dropout_prob_) {
phi::funcs::DropoutFwGPUKernelDriver<T>(
static_cast<const phi::GPUContext&>(dev_ctx_),
dropout_param_.is_test_,
dropout_param_.dropout_prob_,
dropout_param_.is_upscale_in_train_,
dropout_param_.is_fix_seed_,
dropout_param_.seed_val_,
static_cast<const phi::DenseTensor&>(*softmax_out_tensor),
dropout_param_.seed_,
dropout_mask_out_tensor,
dropout_out_tensor,
false);
T* dropout_out_data = dropout_out_tensor->data<T>();
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
dropout_out_data,
v_ptr,
beta,
qktv_out_data,
gemm_batch_size,
stride_a,
stride_b);
} else {
// softmax_out * v, batched_gemm
// output shape: [batch_size, num_heads, seq_len, head_dim]
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
softmax_out_data,
v_ptr,
beta,
qktv_out_data,
gemm_batch_size,
stride_a,
stride_b);
}
// transpose: [0, 2, 1, 3]
// output shape: [batch_size, seq_len, num_heads, head_dim]
std::vector<int> perm_3 = {0, 2, 1, 3};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
}
void ComputeForwardWithoutTranspose(
const phi::DenseTensor* cache_kv_tensor,
const phi::DenseTensor* src_mask_tensor,
const phi::DenseTensor* padding_offset_tensor,
phi::DenseTensor* q_transpose_out_tensor,
phi::DenseTensor* kv_transpose_out_tensor,
phi::DenseTensor* cache_kv_out_tensor,
phi::DenseTensor* qk_out_tensor,
phi::DenseTensor* src_mask_out_tensor,
phi::DenseTensor* softmax_out_tensor,
phi::DenseTensor* dropout_mask_out_tensor,
phi::DenseTensor* dropout_out_tensor,
phi::DenseTensor* qktv_out_tensor,
phi::DenseTensor* fmha_out_tensor,
const int token_num) {
    // Unlike ComputeForward, q/kv arrive already transposed:
    // q_transpose_out: [bs, num_head, seq_len, head_dim]
    // kv_transpose_out: [2, bs, num_head, seq_len, head_dim]
T* qk_out_data = qk_out_tensor->data<T>();
T* qktv_out_data = qktv_out_tensor->data<T>();
T* softmax_out_data = softmax_out_tensor->data<T>();
T* dropout_out_data = dropout_out_tensor->data<T>();
T* fmha_out_data = fmha_out_tensor->data<T>();
auto out_seq_len = seq_len_;
if (cache_kv_tensor) {
// kv [2, bs, num_head, seq_len, head_dim]
phi::funcs::ConcatFunctor<phi::GPUContext, T> concat;
// out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
concat(dev_ctx_,
{*cache_kv_tensor, *kv_transpose_out_tensor},
3,
cache_kv_out_tensor);
out_seq_len = cache_kv_out_tensor->dims()[3];
}
int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
T* q_ptr = q_transpose_out_tensor->data<T>();
T* k_ptr = nullptr;
T* v_ptr = nullptr;
if (cache_kv_tensor) {
int64_t k_size = cache_kv_out_tensor->numel() / 2;
k_ptr = cache_kv_out_tensor->data<T>();
v_ptr = k_ptr + k_size;
} else {
int64_t k_size = q_size;
k_ptr = kv_transpose_out_tensor->data<T>();
v_ptr = k_ptr + k_size;
}
{
// NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for
// float16 calculation, INF may appear in QK^T if we do not scale before.
float alpha = 1.0 / sqrt(head_dim_);
auto functor = phi::funcs::ScaleFunctor<T>(alpha);
std::vector<const phi::DenseTensor*> ins = {q_transpose_out_tensor};
std::vector<phi::DenseTensor*> outs = {q_transpose_out_tensor};
phi::funcs::ElementwiseKernel<T>(dev_ctx_, ins, &outs, functor);
}
// q*k^t, batched_gemm
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasTrans;
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
int gemm_batch_size = batch_size_ * num_head_;
int gemm_m = seq_len_;
int gemm_n = out_seq_len;
int gemm_k = head_dim_;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
int64_t stride_a = gemm_m * gemm_k;
int64_t stride_b = gemm_k * gemm_n;
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
q_ptr,
k_ptr,
beta,
qk_out_data,
gemm_batch_size,
stride_a,
stride_b);
int softmax_axis = -1;
if (src_mask_tensor != nullptr) {
if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) {
phi::fusion::LaunchFusedSoftmaxMaskKernel<T>(qk_out_data,
src_mask_tensor->data<T>(),
softmax_out_data,
batch_size_,
num_head_,
seq_len_,
dev_ctx_.stream());
} else {
std::vector<const phi::DenseTensor*> ins;
std::vector<phi::DenseTensor*> outs;
ins.emplace_back(qk_out_tensor);
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_,
ins,
&outs,
elewise_add_axis,
phi::funcs::AddFunctor<T>());
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
}
} else {
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor);
}
transB = CblasNoTrans;
gemm_m = seq_len_;
gemm_n = head_dim_;
gemm_k = out_seq_len;
alpha = static_cast<T>(1.0);
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
if (dropout_param_.dropout_prob_) {
phi::funcs::DropoutFwGPUKernelDriver<T>(
static_cast<const phi::GPUContext&>(dev_ctx_),
dropout_param_.is_test_,
dropout_param_.dropout_prob_,
dropout_param_.is_upscale_in_train_,
dropout_param_.is_fix_seed_,
dropout_param_.seed_val_,
static_cast<const phi::DenseTensor&>(*softmax_out_tensor),
dropout_param_.seed_,
dropout_mask_out_tensor,
dropout_out_tensor,
false);
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
dropout_out_data,
v_ptr,
beta,
qktv_out_data,
gemm_batch_size,
stride_a,
stride_b);
} else {
// softmax_out * v, batched_gemm
// output shape: [batch_size, num_heads, seq_len, head_dim]
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
softmax_out_data,
v_ptr,
beta,
qktv_out_data,
gemm_batch_size,
stride_a,
stride_b);
}
// transpose: [0, 2, 1, 3]
// output shape: [batch_size, seq_len, num_heads, head_dim]
if (!padding_offset_tensor) {
std::vector<int> perm_3 = {0, 2, 1, 3};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
} else {
InvokeTransposeRemovePadding<T>(dev_ctx_,
qktv_out_data,
fmha_out_data,
batch_size_,
num_head_,
seq_len_,
head_dim_,
token_num,
padding_offset_tensor->data<int>());
}
}
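  // Backward of the attention core: starting from d(fmha_out), reverse the
  // output transpose, the softmax_out (or dropout_out) * V GEMM, dropout,
  // softmax (plus the optional src_mask add) and the scaled Q*K^T GEMM, then
  // transpose the packed [3, bs, num_head, seq_len, head_dim] gradient back
  // to the qkv input layout.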
void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor,
const phi::DenseTensor* src_mask_tensor,
const phi::DenseTensor& softmax_out_tensor,
const phi::DenseTensor& dropout_mask_out_tensor,
const phi::DenseTensor& dropout_out_tensor,
const phi::DenseTensor& qk_out_tensor,
const phi::DenseTensor& src_mask_out_tensor,
const phi::DenseTensor& fmha_out_grad_tensor,
phi::DenseTensor* qktv_out_grad_tensor,
phi::DenseTensor* dropout_out_grad_tensor,
phi::DenseTensor* softmax_out_grad_tensor,
phi::DenseTensor* src_mask_out_grad_tensor,
phi::DenseTensor* qk_out_grad_tensor,
phi::DenseTensor* transpose_2_out_grad_tensor,
phi::DenseTensor* src_mask_grad_tensor,
phi::DenseTensor* qkv_input_grad_tensor) {
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
int k_size = q_size;
int softmax_axis = -1;
T* qkv_grad_data = transpose_2_out_grad_tensor->data<T>();
T* q_grad_ptr = qkv_grad_data;
T* k_grad_ptr = q_grad_ptr + q_size;
T* v_grad_ptr = k_grad_ptr + k_size;
const T* qkv_data = transpose_2_out_tensor.data<T>();
const T* q_ptr = qkv_data;
const T* k_ptr = q_ptr + q_size;
const T* v_ptr = k_ptr + k_size;
const T* softmax_out_data = softmax_out_tensor.data<T>();
T* softmax_out_grad_data = softmax_out_grad_tensor->data<T>();
T* qktv_out_grad_data = qktv_out_grad_tensor->data<T>();
// transpose bw
std::vector<int> perm_3 = {0, 2, 1, 3};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, fmha_out_grad_tensor, perm_3, qktv_out_grad_tensor);
// recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) =
// qktv_out_data(out)
CBLAS_TRANSPOSE transA = CblasTrans;
CBLAS_TRANSPOSE transB = CblasNoTrans;
int gemm_batch_size = batch_size_ * num_head_;
int gemm_m = seq_len_;
int gemm_n = head_dim_;
int gemm_k = seq_len_;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
int64_t stride_a = gemm_m * gemm_k;
int64_t stride_b = gemm_k * gemm_n;
// bw: dy = x^t * dout
if (dropout_param_.dropout_prob_) {
const T* dropout_out_data = dropout_out_tensor.data<T>();
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
dropout_out_data,
qktv_out_grad_data,
beta,
v_grad_ptr,
gemm_batch_size,
stride_a,
stride_b);
} else {
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
softmax_out_data,
qktv_out_grad_data,
beta,
v_grad_ptr,
gemm_batch_size,
stride_a,
stride_b);
}
// bw: dx = dout * y^t
transA = CblasNoTrans;
transB = CblasTrans;
gemm_m = seq_len_;
gemm_n = seq_len_;
gemm_k = head_dim_;
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
if (dropout_param_.dropout_prob_) {
T* dropout_out_grad_data = dropout_out_grad_tensor->data<T>();
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
qktv_out_grad_data,
v_ptr,
beta,
dropout_out_grad_data,
gemm_batch_size,
stride_a,
stride_b);
} else {
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
qktv_out_grad_data,
v_ptr,
beta,
softmax_out_grad_data,
gemm_batch_size,
stride_a,
stride_b);
}
// dropout bw
if (dropout_param_.dropout_prob_) {
phi::funcs::DropoutGradGPUKernelDriver<T>(
static_cast<const phi::GPUContext&>(dev_ctx_),
false,
dropout_param_.dropout_prob_,
dropout_param_.is_upscale_in_train_,
static_cast<const phi::DenseTensor&>(*dropout_out_grad_tensor),
dropout_mask_out_tensor,
softmax_out_grad_tensor,
false);
}
if (src_mask_tensor != nullptr) {
phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_,
softmax_out_tensor,
*softmax_out_grad_tensor,
softmax_axis,
src_mask_out_grad_tensor);
// recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out +
// src_mask
      // Special case: dy is not needed and dx does not need to be reduced
if (qk_out_grad_tensor != nullptr && src_mask_grad_tensor == nullptr &&
qk_out_tensor.dims() == src_mask_out_tensor.dims()) {
VLOG(4) << "Special case when dy is not needed and dx doesn't "
"reduce";
phi::Copy(dev_ctx_,
*src_mask_out_grad_tensor,
dev_ctx_.GetPlace(),
false,
qk_out_grad_tensor);
} else {
        PADDLE_THROW(errors::InvalidArgument(
            "Only used for the backward elementwise_add op when "
            "dy is not needed and dx does not need to be reduced."));
return;
}
} else {
phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_,
softmax_out_tensor,
*softmax_out_grad_tensor,
softmax_axis,
qk_out_grad_tensor);
}
T* qk_out_grad_data = qk_out_grad_tensor->data<T>();
// NOTE(wangxi): For we scale Q with 1/sqrt(Dh) in forward, so we set
// alpha = 1.0 in backward.
alpha = static_cast<T>(1.0);
// recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out
// bw: dy (seq_len * head_dim) = (dout)^t * x
transA = CblasTrans;
transB = CblasNoTrans;
gemm_m = seq_len_;
gemm_n = head_dim_;
gemm_k = seq_len_;
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
qk_out_grad_data,
q_ptr,
beta,
k_grad_ptr,
gemm_batch_size,
stride_a,
stride_b);
// dx (seq_len * head_dim) = dout * y
alpha = static_cast<T>(1.0 / sqrt(head_dim_));
transA = CblasNoTrans;
transB = CblasNoTrans;
gemm_m = seq_len_;
gemm_n = head_dim_;
gemm_k = seq_len_;
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
blas.BatchedGEMM(transA,
transB,
gemm_m,
gemm_n,
gemm_k,
alpha,
qk_out_grad_data,
k_ptr,
beta,
q_grad_ptr,
gemm_batch_size,
stride_a,
stride_b);
// transpose bw
std::vector<int> perm_1 = {1, 3, 0, 2, 4};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, *transpose_2_out_grad_tensor, perm_1, qkv_input_grad_tensor);
}
private:
const phi::GPUContext& dev_ctx_;
int64_t batch_size_;
int64_t seq_len_;
int64_t num_head_;
int64_t head_dim_;
AttnDropoutParam dropout_param_;
};
} // namespace fusion
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_common.h"
#include "paddle/phi/kernels/fusion/gpu/fused_residual_dropout_bias.h"
#include "paddle/phi/kernels/gpu/gelu_funcs.h"
namespace phi {
namespace fusion {
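// Exact (erf-based) GELU: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
// The computation is carried out in LayerNormParamType<T> to avoid precision
// loss for low-precision T.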
template <typename T>
struct GeluFunctor {
inline __host__ __device__ T operator()(const T x) const {
using U = phi::funcs::LayerNormParamType<T>;
const U casted_x = static_cast<U>(x);
const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
return static_cast<T>(out);
}
};
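// Fast GELU path implemented by phi::GeluFwd; selected by
// FusedDropoutHelper::DropoutActBias when FLAGS_use_fast_math is enabled.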
template <typename T>
struct FastGeluFunctor {
inline __device__ T operator()(const T x) const {
return phi::GeluFwd<T, true>(x);
}
};
/**
*@brief the gelu grad functor
*/
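// d/dx gelu(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi);
// UseOut computes exactly this (note that 0.5 * M_2_SQRTPI * M_SQRT1_2 equals
// 1 / sqrt(2 * pi)).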
template <typename T>
struct GeluGradFunctor {
inline __host__ __device__ T UseOut(const T x) const {
using U = phi::funcs::LayerNormParamType<T>;
auto casted_x = static_cast<U>(x);
auto first =
static_cast<U>(0.5) *
(static_cast<U>(1) + erf(casted_x * static_cast<U>(M_SQRT1_2)));
auto second = static_cast<U>(0.5 * M_2_SQRTPI * M_SQRT1_2) * casted_x *
exp(-static_cast<U>(0.5) * casted_x * casted_x);
return static_cast<T>((first + second));
}
};
/**
* @brief dst = dropout(activation(src + bias));
* the src, mask and dst shape is (rows, cols)
* the bias shape is (1, cols)
*/
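// Thread mapping: gridDim.y strides over the rows, and each thread strides
// over the columns in steps of blockDim.x * gridDim.x * VecSize, handling
// VecSize contiguous elements per step (launch config comes from
// Get1DBlocksAnd2DGrids).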
template <typename T,
typename MaskType,
int VecSize,
typename Functor,
typename InType = T,
typename OutType = T>
__global__ void FusedDropoutActBias(
Functor act,
const uint64_t seed,
const uint64_t rows,
const uint64_t cols,
const int increment,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const InType *__restrict__ src,
const T *__restrict__ bias,
OutType *dst,
MaskType *mask,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
const T factor =
phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
i += blockDim.x * gridDim.x * VecSize) {
phi::fusion::FusedResidualDropoutBiasOneThread<T,
MaskType,
VecSize,
false,
true,
Functor,
InType,
OutType>(
r,
i,
cols,
&state,
dropout_prob,
factor,
src,
nullptr,
bias,
dst,
mask,
is_test,
nullptr,
nullptr,
act,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
}
}
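/**
 * @brief dst = activation(src + bias), without dropout; used on the is_test
 * fast path of LaunchDropoutActBias. When InType is int32_t, the input is
 * dequantized with dequant_out_scale_data and the activated result is
 * re-quantized with quant_next_in_scale.
 */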
template <typename T,
int VecSize,
typename Functor,
typename InType = T,
typename OutType = T>
__global__ void FusedActBias(Functor act,
const uint64_t elem_cnt,
const uint64_t cols,
const InType *__restrict__ src,
const T *__restrict__ bias,
OutType *dst,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
const int32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using LoadInType = phi::AlignedVector<InType, VecSize>;
using LoadFloat = phi::AlignedVector<float, VecSize>;
using StoreOutType = phi::AlignedVector<OutType, VecSize>;
LoadInType src_vec;
LoadT bias_vec;
StoreOutType out_vec;
LoadFloat dequant_out_scale_vec;
for (int32_t idx = global_thread_idx * VecSize,
step = blockDim.x * gridDim.x * VecSize;
idx < elem_cnt;
idx += step) {
const int32_t col_idx = idx % cols;
phi::Load<InType, VecSize>(&src[idx], &src_vec);
phi::Load<float, VecSize>(&dequant_out_scale_data[col_idx],
&dequant_out_scale_vec);
if (bias) {
phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
}
#pragma unroll
for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
T tmp;
if (std::is_same<InType, int32_t>::value) {
tmp = static_cast<T>(static_cast<float>(src_vec[unroll_idx]) *
dequant_out_scale_vec[unroll_idx]);
if (bias) {
tmp = static_cast<T>(act(tmp + bias_vec[unroll_idx]));
} else {
tmp = static_cast<T>(act(tmp));
}
out_vec[unroll_idx] = phi::funcs::quant_helper(tmp,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
if (bias) {
out_vec[unroll_idx] = static_cast<OutType>(
act(static_cast<T>(src_vec[unroll_idx]) + bias_vec[unroll_idx]));
} else {
out_vec[unroll_idx] =
static_cast<OutType>(act(static_cast<T>(src_vec[unroll_idx])));
}
}
}
phi::Store<OutType, VecSize>(out_vec, &dst[idx]);
}
}
/**
* @brief dst = dropout(activation(src + bias));
*/
template <typename T,
typename MaskType,
typename Functor,
typename InType = T,
typename OutType = T>
void LaunchDropoutActBias(Functor act_functor,
const uint64_t seed,
const uint32_t rows,
const uint32_t cols,
const int increment,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const InType *src,
const T *bias,
OutType *dst,
MaskType *mask_data,
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
phi::fusion::SetZero<T>(ctx, reinterpret_cast<T *>(dst), rows * cols);
phi::fusion::SetZero<MaskType>(ctx, mask_data, rows * cols);
return;
}
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
const auto config =
phi::fusion::Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
if (is_test) {
const int32_t elem_cnt = rows * cols;
const int32_t pack_num = elem_cnt / VecSize;
const int32_t tmp_cols = cols / VecSize;
int block_size =
std::max(static_cast<int32_t>(32), std::min(tmp_cols, 128));
const int grid_size = std::max(static_cast<int32_t>(1),
(pack_num + block_size - 1) / block_size);
FusedActBias<T, VecSize, Functor, InType, OutType>
<<<grid_size, block_size, 0, ctx.stream()>>>(act_functor,
elem_cnt,
cols,
src,
bias,
dst,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
} else {
FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor,
seed,
rows,
cols,
increment,
dropout_prob,
is_upscale_in_train,
is_test,
src,
bias,
dst,
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
}
} else {
FusedDropoutActBias<T, MaskType, 1, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor,
seed,
rows,
cols,
increment,
dropout_prob,
is_upscale_in_train,
is_test,
src,
bias,
dst,
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
}
}
/*
 * @brief Calculate the gradient for the case without bias.
 */
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActGrad(Functor act_grad,
const T *dout,
const MaskType *mask,
const T *src,
const T factor,
const int64_t size,
T *dx) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
LoadT dout_vec;
LoadT src_vec;
MaskLoadT mask_vec;
phi::Load<T, VecSize>(&dout[i], &dout_vec);
phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);
phi::Load<T, VecSize>(&src[i], &src_vec);
StoreT dx_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
T tmp = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]);
}
phi::Store<T, VecSize>(dx_vec, &dx[i]);
}
}
/**
* blocks(128 * 8)
* 1. calculate the dx and reduce total rows to 128 rows
* 2. save 128*8 temporary sum in 8*128 shared memory
* 3. reduce the sum of 128 cols data by 8*VecSize warps
*/
template <typename T,
typename MaskType,
int BlockSizeX,
int BlockSizeY,
int VecSize,
typename Functor,
int THREADS_PER_CTA = BlockSizeX *BlockSizeY>
__global__ __launch_bounds__(THREADS_PER_CTA) void FusedDropoutActBiasGrad(
Functor act_grad,
const T *dout,
const MaskType *mask,
const T *src,
const T *bias,
const T factor,
const int64_t rows,
const int64_t cols,
T *dx,
T *dbias) {
int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
T tmp_sum[VecSize] = {static_cast<T>(0)};
// calculate the dx and temporary sum
if (col_id * VecSize < cols) {
for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
int index = row_id * cols + col_id * VecSize;
LoadT dout_vec;
LoadT src_vec;
LoadT bias_vec;
MaskLoadT mask_vec;
phi::Load<T, VecSize>(&dout[index], &dout_vec);
phi::Load<T, VecSize>(&src[index], &src_vec);
phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
phi::Load<T, VecSize>(&bias[col_id * VecSize], &bias_vec);
StoreT dx_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
T val;
T tmp = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
val = tmp * act_grad.UseOut(src_vec[i] + bias_vec[i]);
dx_vec[i] = val;
tmp_sum[i] += val;
}
phi::Store<T, VecSize>(dx_vec, &dx[index]);
}
}
phi::fusion::CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(
tmp_sum, dbias, cols);
}
/**
 * @brief Launch FusedDropoutActBiasGrad (bias path) or FusedDropoutActGrad
 * (no-bias path).
 */
template <typename T, typename MaskType, typename Functor>
void LaunchDropoutActBiasGrad(Functor act_functor,
const T *dout,
const MaskType *mask,
const T *src,
const T *bias,
const float dropout_prob,
const bool is_upscale_in_train,
const uint32_t rows,
const uint32_t cols,
T *dx,
T *dbias,
const phi::GPUContext &ctx) {
const T zero = static_cast<T>(0.0);
auto factor = dropout_prob == static_cast<float>(1.0f)
? zero
: static_cast<T>(1.0 / (1.0 - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
if (dbias != nullptr) {
const auto threads = 8;
const auto blocks =
std::max(static_cast<uint32_t>(1),
(cols / real_vec_size + threads - 1) / threads);
dim3 block_dim(threads, 128, 1);
dim3 grid_dim(blocks, 1, 1);
if (cols % VecSize == 0) {
FusedDropoutActBiasGrad<T, MaskType, 8, 128, VecSize, Functor>
<<<grid_dim, block_dim, 0, ctx.stream()>>>(act_functor,
dout,
mask,
src,
bias,
factor,
rows,
cols,
dx,
dbias);
} else {
FusedDropoutActBiasGrad<T, MaskType, 8, 128, 1, Functor>
<<<grid_dim, block_dim, 0, ctx.stream()>>>(act_functor,
dout,
mask,
src,
bias,
factor,
rows,
cols,
dx,
dbias);
}
} else {
const uint64_t n = rows * cols;
phi::backends::gpu::GpuLaunchConfig config =
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
if (n % VecSize == 0) {
FusedDropoutActGrad<T, MaskType, VecSize, Functor>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, dout, mask, src, factor, n, dx);
} else {
FusedDropoutActGrad<T, MaskType, 1, Functor>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, dout, mask, src, factor, n, dx);
}
}
}
} // namespace fusion
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CUDA)
#include <cooperative_groups.h>
#include <cuda.h>
#include <curand_kernel.h>
#endif
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
namespace phi {
namespace fusion {
#define CACHE_LINE 128
#define MAX_CACHE_BYTES (CACHE_LINE / CHAR_BIT)
/**
* get the threads for fused_residual_dropout_bias:
* 1D blocks: blockDim.x = cols
* 2D grids: gridDim.y = rows
*/
inline phi::backends::gpu::GpuLaunchConfig Get1DBlocksAnd2DGrids(
const phi::GPUContext &ctx,
const uint32_t rows,
const uint32_t cols,
const int vec_size) {
const uint32_t tmp_cols = cols / vec_size;
  // NOTE(wangxi): We set max_block_size to 512 because `FusedResidualDropoutBias`
  // needs too many register resources. If data_type is float16 and block_size
  // is 1024, CUDA error 701 ('cudaErrorLaunchOutOfResources') occurs, which
  // means the launch failed because it did not have appropriate resources.
  // Of course, this kernel can be optimized later to reduce register usage.
int threads = std::max(static_cast<uint32_t>(32),
std::min(tmp_cols,
static_cast<uint32_t>(std::min(
ctx.GetMaxThreadsPerBlock(), 512))));
const auto blocks_x =
std::max(static_cast<uint32_t>(1), (tmp_cols + threads - 1) / threads);
const auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
phi::backends::gpu::GpuLaunchConfig config;
config.block_per_grid.x = blocks_x;
config.block_per_grid.y = blocks_y;
config.thread_per_block.x = threads;
return config;
}
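// Example for the config above: rows = 4096, cols = 1024, vec_size = 8 gives
// tmp_cols = 128, threads = 128 (assuming GetMaxThreadsPerBlock() >= 128),
// block_per_grid = (1, 4096): one block row per input row, with 128 threads
// covering the 1024 columns in vectors of 8.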
template <int VecSize>
__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
float *data);
template <>
__forceinline__ __device__ void RandVec<1>(curandStatePhilox4_32_10_t *state,
float *data) {
data[0] = curand_uniform(state);
}
template <>
__forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state,
float *data) {
data[0] = curand_uniform(state);
data[1] = curand_uniform(state);
}
template <>
__forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state,
float *data) {
float4 rand4 = curand_uniform4(state);
data[0] = rand4.x;
data[1] = rand4.y;
data[2] = rand4.w;
data[3] = rand4.z;
}
template <>
__forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state,
float *data) {
RandVec<4>(state, data);
RandVec<4>(state, data + 4);
}
template <typename T>
inline void SetZero(const phi::GPUContext &ctx, T *ptr, const size_t size) {
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(ptr, 0, size * sizeof(T), ctx.stream()));
}
/**
* reduce the sum of 128 cols data by 8*VecSize warps
**/
template <typename T, int VecSize, int BlockSizeX, int BlockSizeY>
inline __device__ void CalculateDBias(const T *tmp_sum,
T *dbias,
const int cols) {
// save temporary sum to cache and do transpose
__shared__ T cache[BlockSizeX * VecSize][BlockSizeY];
for (int i = 0; i < VecSize; i++) {
cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i];
}
__syncthreads();
// reduce sum
T sum[2] = {static_cast<T>(0)};
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int x = tid >> 5; // warp id
int y = tid & 31; // thread id on warp 0~31
// need BlockSizeX * VecSize warps
for (int j = x; j < BlockSizeX * VecSize; j += 32) {
// reduce 128 to 32
#pragma unroll
for (int i = 0; i < (BlockSizeY >> 5); i++) {
sum[(j >> 5)] += cache[j][y + i * 32];
}
}
int reduce_num_pre_thread = (BlockSizeX * VecSize + 31) / 32;
// reduce 32 to 1
for (int i = 0; i < reduce_num_pre_thread; i++) {
sum[i] = phi::funcs::WarpReduceSum(sum[i]);
}
// save sum to dbias
if (y == 0 && x < BlockSizeX * VecSize) {
for (int i = 0; i < reduce_num_pre_thread; i++) {
int bias_id = blockIdx.x * BlockSizeX * VecSize + x + i * 32;
if (bias_id < cols) {
dbias[bias_id] = sum[i];
}
}
}
}
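// Forward dropout scaling factor:
//   train, upscale_in_train -> 1 / (1 - dropout_prob)
//   train, no upscale       -> 1
//   test,  upscale_in_train -> 1
//   test,  no upscale       -> 1 - dropout_prob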
template <typename T>
inline __device__ T GetFactor(const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test) {
T factor = is_upscale_in_train ? static_cast<T>(1.0f / (1.0f - dropout_prob))
: static_cast<T>(1.0f);
if (is_test) {
factor = is_upscale_in_train ? static_cast<T>(1.0f)
: static_cast<T>(1.0f - dropout_prob);
}
return factor;
}
} // namespace fusion
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_CUDA)
#include "paddle/phi/backends/dynload/cublasLt.h"
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_common.h"
#include "paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h"
#include "paddle/phi/kernels/fusion/gpu/fused_residual_dropout_bias.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
namespace phi {
namespace fusion {
struct DropoutParam {
uint64_t seed;
float dropout_prob;
bool is_upscale_in_train;
bool is_test;
bool fix_seed;
int increment{};
const phi::DenseTensor* tensor_seed;
int seed_val;
DropoutParam() {
fix_seed = false;
seed = 0;
is_test = false;
is_upscale_in_train = false;
dropout_prob = 0.5;
tensor_seed = nullptr;
seed_val = 0;
}
DropoutParam(bool fix_seed_,
uint64_t seed_,
bool is_test_,
bool is_upscale_in_train_,
float dropout_prob_,
const phi::DenseTensor* tensor_seed_,
int seed_val_) {
fix_seed = fix_seed_;
seed = seed_;
is_test = is_test_;
is_upscale_in_train = is_upscale_in_train_;
dropout_prob = dropout_prob_;
tensor_seed = tensor_seed_;
seed_val = seed_val_;
}
int UpdateSeedAndIncrement(const phi::GPUContext& dev_ctx, const int offset) {
uint64_t tmp_increment;
phi::funcs::GetSeedDataAndIncrement(dev_ctx,
tensor_seed,
fix_seed,
seed_val,
offset,
&seed,
&tmp_increment);
increment = static_cast<int>(tmp_increment);
return increment;
}
};
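// A minimal usage sketch (the values and `dev_ctx`/`offset` below are
// illustrative, not part of this header): build a DropoutParam for training
// with a fixed seed, then derive the Philox increment before launching a
// fused kernel:
//   DropoutParam param(/*fix_seed_=*/true, /*seed_=*/42, /*is_test_=*/false,
//                      /*is_upscale_in_train_=*/true, /*dropout_prob_=*/0.1f,
//                      /*tensor_seed_=*/nullptr, /*seed_val_=*/42);
//   int increment = param.UpdateSeedAndIncrement(dev_ctx, offset);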
template <typename T>
struct DataTypeTraits {
using DataType = T;
};
template <>
struct DataTypeTraits<phi::dtype::float16> {
  // Since LayerNormDirectCUDAFunctor registers the half type, we need to
  // convert phi::float16 to half.
using DataType = half;
};
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
class FusedDropoutHelper {
private:
int GetIncrement(const phi::GPUContext& ctx) {
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
const int real_vec_size = cols_ % VecSize == 0 ? VecSize : 1;
auto config = Get1DBlocksAnd2DGrids(ctx,
static_cast<uint64_t>(rows_),
static_cast<uint64_t>(cols_),
real_vec_size);
int increment = ((cols_ - 1) / (config.thread_per_block.x *
config.block_per_grid.x * real_vec_size) +
1) *
real_vec_size;
increment = dropout_param_.UpdateSeedAndIncrement(ctx, increment);
return increment;
}
public:
FusedDropoutHelper() {}
FusedDropoutHelper(const phi::GPUContext& ctx,
const int rows,
const int cols,
const DropoutParam& dropout_param) {
rows_ = rows;
cols_ = cols;
dropout_param_ = dropout_param;
}
// out = residual + dropout( src + bias )
void ResidualDropoutBias(const phi::GPUContext& ctx,
const InType* src,
const T* residual,
const T* bias,
OutType* out,
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0) {
auto increment = GetIncrement(ctx);
LaunchResidualDropoutBias<T, MaskType, InType, OutType>(
rows_,
cols_,
increment,
dropout_param_.seed,
dropout_param_.dropout_prob,
dropout_param_.is_test,
dropout_param_.is_upscale_in_train,
src,
residual,
bias,
mask,
out,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
}
void ResidualDropoutBiasGrad(const phi::GPUContext& ctx,
const T* d_out,
const MaskType* mask,
T* d_src,
T* d_residual,
T* d_bias) {
LaunchResidualDropoutBiasGrad<T, uint8_t>(
d_out,
mask,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
rows_,
cols_,
d_src,
d_bias,
ctx);
if (d_residual) {
phi::memory_utils::Copy(ctx.GetPlace(),
d_residual,
ctx.GetPlace(),
d_out,
rows_ * cols_ * sizeof(T),
ctx.stream());
}
}
// out = dropout(activation(src + bias))
void DropoutActBias(const phi::GPUContext& ctx,
const InType* src,
const T* bias,
const std::string& act_method,
OutType* out,
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto increment = GetIncrement(ctx);
if (act_method == "gelu") {
if (FLAGS_use_fast_math) {
phi::fusion::FastGeluFunctor<T> fast_gelu;
phi::fusion::LaunchDropoutActBias<T,
MaskType,
phi::fusion::FastGeluFunctor<T>,
InType,
OutType>(
fast_gelu,
dropout_param_.seed,
rows_,
cols_,
dropout_param_.increment,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
dropout_param_.is_test,
src,
bias,
out,
mask,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
phi::fusion::GeluFunctor<T> gelu;
phi::fusion::LaunchDropoutActBias<T,
MaskType,
phi::fusion::GeluFunctor<T>,
InType,
OutType>(
gelu,
dropout_param_.seed,
rows_,
cols_,
dropout_param_.increment,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
dropout_param_.is_test,
src,
bias,
out,
mask,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
} else if (act_method == "relu") {
phi::funcs::ReluFunctor<T> relu;
phi::fusion::LaunchDropoutActBias<T,
MaskType,
phi::funcs::ReluFunctor<T>,
InType,
OutType>(
relu,
dropout_param_.seed,
rows_,
cols_,
increment,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
dropout_param_.is_test,
src,
bias,
out,
mask,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
PADDLE_THROW(errors::InvalidArgument(
"Currently only supports gelu or relu activation functions!"));
}
}
void DropoutActBiasGrad(const phi::GPUContext& ctx,
const T* dout,
const T* src,
const T* bias,
const MaskType* mask,
T* d_src,
T* d_bias,
const std::string& act_method) {
if (act_method == "gelu") {
phi::funcs::GeluGradFunctor<T> gelu_grad;
phi::fusion::
LaunchDropoutActBiasGrad<T, MaskType, phi::funcs::GeluGradFunctor<T>>(
gelu_grad,
dout,
mask,
src,
bias,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
rows_,
cols_,
d_src,
d_bias,
ctx);
} else if (act_method == "relu") {
phi::funcs::ReluGradFunctor<T> relu_grad;
phi::fusion::
LaunchDropoutActBiasGrad<T, MaskType, phi::funcs::ReluGradFunctor<T>>(
relu_grad,
dout,
mask,
src,
bias,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
rows_,
cols_,
d_src,
d_bias,
ctx);
} else {
PADDLE_THROW(errors::InvalidArgument(
"Currently only supports gelu or relu activation functions!"));
}
}
protected:
int rows_;
int cols_;
DropoutParam dropout_param_;
};
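// A minimal usage sketch (tensor pointers are illustrative): the helper binds
// (rows, cols, DropoutParam) once and is reused for forward and backward:
//   FusedDropoutHelper<float, uint8_t> helper(dev_ctx, rows, cols, param);
//   helper.DropoutActBias(dev_ctx, src, bias, "gelu", out, mask);
//   helper.DropoutActBiasGrad(dev_ctx, d_out, src, bias, mask, d_src, d_bias,
//                             "gelu");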
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
class FusedDropoutLayerNormHelper
: public FusedDropoutHelper<T, MaskType, InType, OutType> {
public:
FusedDropoutLayerNormHelper() {}
FusedDropoutLayerNormHelper(const int rows,
const int cols,
const float epsilon) {
using U = phi::funcs::LayerNormParamType<T>;
this->rows_ = rows;
this->cols_ = cols;
epsilon_ = epsilon;
}
FusedDropoutLayerNormHelper(const phi::GPUContext& ctx,
const int rows,
const int cols,
const DropoutParam& dropout_param,
const float epsilon)
: FusedDropoutHelper<T, MaskType, InType, OutType>(
ctx, rows, cols, dropout_param) {
using U = phi::funcs::LayerNormParamType<T>;
epsilon_ = epsilon;
}
// call layer_norm
void LayerNorm(const phi::GPUContext& ctx,
const InType* src,
const phi::funcs::LayerNormParamType<T>* gamma,
const phi::funcs::LayerNormParamType<T>* beta,
OutType* out,
phi::funcs::LayerNormParamType<T>* mean,
phi::funcs::LayerNormParamType<T>* variance) {
using InDataType = typename DataTypeTraits<InType>::DataType;
using OutDataType = typename DataTypeTraits<OutType>::DataType;
phi::LayerNormDirectCUDAFunctor<InDataType,
phi::funcs::LayerNormParamType<T>>
layer_norm;
std::vector<int> src_shape{this->rows_, this->cols_};
layer_norm(ctx.stream(),
reinterpret_cast<const InDataType*>(src),
src_shape,
beta,
gamma,
reinterpret_cast<OutDataType*>(out),
mean,
variance,
1,
epsilon_);
}
void LayerNormGrad(const phi::GPUContext& ctx,
const T* dout,
const T* src,
const phi::funcs::LayerNormParamType<T>* gamma,
const phi::funcs::LayerNormParamType<T>* mean,
const phi::funcs::LayerNormParamType<T>* variance,
T* d_src,
phi::funcs::LayerNormParamType<T>* d_scale,
phi::funcs::LayerNormParamType<T>* d_bias) {
using U = phi::funcs::LayerNormParamType<T>;
phi::funcs::LayerNormBackward<T, U>(src,
dout,
gamma,
mean,
variance,
d_src,
d_scale,
d_bias,
epsilon_,
this->rows_,
this->cols_,
ctx);
}
// out = layernorm(residual + dropout(src + bias))
template <typename P = phi::funcs::LayerNormParamType<T>,
bool is_same_type = false>
void LayernormResidualDropoutBias(
const phi::GPUContext& ctx,
const InType* src,
const T* residual,
const T* bias,
const P* gamma,
const P* beta,
T* dropout_out,
MaskType* mask,
OutType* out,
phi::funcs::LayerNormParamType<T>* mean,
phi::funcs::LayerNormParamType<T>* variance,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
using U = phi::funcs::LayerNormParamType<T>;
int vec_size = MAX_CACHE_BYTES / sizeof(T);
if (this->cols_ % vec_size != 0) {
vec_size = 1;
}
int threads = phi::funcs::GetDesiredBlockDim(this->cols_ / vec_size);
int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
LaunchLayernormResidualDropoutBias<T,
MaskType,
U,
is_same_type,
InType,
OutType>(
this->rows_,
this->cols_,
increment,
this->dropout_param_.seed,
this->dropout_param_.dropout_prob,
epsilon_,
this->dropout_param_.is_upscale_in_train,
this->dropout_param_.is_test,
src,
residual,
bias,
gamma,
beta,
mask,
dropout_out,
out,
mean,
variance,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
template <typename P = phi::funcs::LayerNormParamType<T>,
bool is_same_type = false>
void LayernormResidualDropoutBiasGrad(
const phi::GPUContext& ctx,
const T* d_out,
const T* layernorm_src,
const MaskType* mask,
const P* gamma,
const phi::funcs::LayerNormParamType<T>* mean,
const phi::funcs::LayerNormParamType<T>* variance,
T* d_layernorm_src,
P* d_scale,
P* d_layernorm_bias,
T* d_dropout_src,
T* d_bias,
T* d_residual) {
using U = phi::funcs::LayerNormParamType<T>;
bool can_call_1024_kernel = false;
    // Fast impl for the case where cols is 1024 and linear_bias is nullptr.
    // In fact, the fast impl is also feasible when linear_bias is not nullptr,
    // but we do not support that here.
if (this->cols_ == 1024 && d_bias == nullptr && d_scale != nullptr &&
d_layernorm_bias != nullptr && sizeof(T) <= 4) {
can_call_1024_kernel = true;
}
VLOG(6) << "LaunchLayernormResidualDropoutGrad = " << can_call_1024_kernel;
if (can_call_1024_kernel) {
LaunchLayernormResidualDropoutGrad<T, U, MaskType, is_same_type>(
ctx,
this->rows_,
this->cols_,
epsilon_,
this->dropout_param_.dropout_prob,
this->dropout_param_.is_upscale_in_train,
d_out,
layernorm_src,
gamma,
mean,
variance,
mask,
d_scale,
d_layernorm_bias,
d_residual,
d_dropout_src);
} else {
phi::funcs::LayerNormBackward<T, U, is_same_type>(layernorm_src,
d_out,
gamma,
mean,
variance,
d_layernorm_src,
d_scale,
d_layernorm_bias,
epsilon_,
this->rows_,
this->cols_,
ctx);
this->ResidualDropoutBiasGrad(
ctx, d_layernorm_src, mask, d_dropout_src, d_residual, d_bias);
}
}
protected:
float epsilon_;
};
} // namespace fusion
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#include "paddle/phi/kernels/fusion/gpu/fused_residual_dropout_bias.h"
namespace phi {
namespace fusion {
#define LN_NUM_COLS 1024
template <typename T>
using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
template <typename T, typename U, bool ScaleBiasWithSameTypeX>
using LayerNormScaleBiasT =
typename std::conditional<ScaleBiasWithSameTypeX, T, U>::type;
/**
 * @brief Fuse add_bias, dropout, add residual and layer_norm into one
 * operator. Currently only the forward pass is supported.
 */
template <typename T,
int VecSize,
typename U,
bool ScaleBiasWithSameTypeX = false>
__device__ void CalcLayernormY(
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *bias,
const T *x,
T *y,
const int row_id,
const int col_id,
const int cols,
const LayerNormParamType<T> mean_val,
const LayerNormParamType<T> invvar) {
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using LoadU = phi::AlignedVector<U, VecSize>;
using LoadScaleOrBias =
phi::AlignedVector<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
VecSize>;
for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
LoadScaleOrBias scale_vec;
LoadScaleOrBias bias_vec;
LoadT x_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
scale_vec[ii] =
static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(1);
bias_vec[ii] =
static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(0);
}
// vectorize load data from global
phi::Load<T, VecSize>(&x[row_id * cols + i], &x_vec);
if (scale != nullptr) {
phi::Load<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, VecSize>(
&scale[i], &scale_vec);
}
if (bias != nullptr) {
phi::Load<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, VecSize>(
&bias[i], &bias_vec);
}
StoreT y_vec;
for (int ii = 0; ii < VecSize; ii++) {
y_vec[ii] =
static_cast<T>(static_cast<U>(scale_vec[ii]) *
(static_cast<U>(x_vec[ii]) - mean_val) * invvar +
static_cast<U>(bias_vec[ii]));
}
phi::Store<T, VecSize>(y_vec, &y[row_id * cols + i]);
}
}
/**
* @brief layernorm(residual + dropout(src + bias));
* @param
* rows: batch_size * seq_len
* cols: feature_size or hidden_size
* src: [rows, cols], inputs
* bias: [cols], linear bias, can be null
* residual:[rows, cols]
* mask: [rows, cols], dropout result
* dst: [rows, cols], residual + dropout(src+bias)
* layernorm_dst: [rows, cols], layernorm result
* layernorm_bias: [cols], layernorm bias, can be null
* scale: [cols]: layernorm scale, can be null
* means: [rows]: layernorm means
* vars: [rows]: layernorm vars
*/
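// Launch layout: one block per row and blockDim.x threads per row, each
// covering VecSize columns per iteration (the functor below launches with
// <<<rows, GetDesiredBlockDim(cols / VecSize)>>>).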
template <typename T,
typename MaskType,
int VecSize,
typename U,
bool ScaleBiasWithSameTypeX = false,
bool HasDropout = true>
__global__ void FusedLayernormResidualDropoutBias(
const size_t rows,
const size_t cols,
uint64_t seed,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const uint64_t increment,
const float epsilon,
const T *src,
const T *residual,
const T *bias,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
MaskType *mask,
T *dst,
T *layernorm_dst,
LayerNormParamType<T> *mean,
LayerNormParamType<T> *var) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
if (HasDropout) {
curand_init(seed, idx, increment, &state);
}
T factor =
phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
__shared__ U mean_share;
__shared__ U var_share;
__shared__ U shared_mean[32];
__shared__ U shared_var[32];
phi::funcs::ReluFunctor<T> relu;
U mean_val = 0;
U var_val = 0;
for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T,
MaskType,
VecSize,
true,
false,
phi::funcs::ReluFunctor<T>,
T,
T,
HasDropout>(row_id,
i,
cols,
&state,
dropout_prob,
factor,
src,
residual,
bias,
dst,
mask,
is_test,
&mean_val,
&var_val,
relu);
}
mean_val = phi::funcs::BlockReduceSum<U>(mean_val, shared_mean);
var_val = phi::funcs::BlockReduceSum<U>(var_val, shared_var);
if (threadIdx.x == 0) {
auto scale = static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
static_cast<float>(1.) / static_cast<float>(cols));
auto tmp = mean_val * static_cast<U>(scale);
mean[row_id] = mean_share = static_cast<U>(tmp);
var_share = static_cast<U>(var_val * static_cast<U>(scale) -
mean_share * mean_share);
var_share = var_share > U(0) ? var_share : U(0);
var[row_id] = var_share;
}
__syncthreads();
mean_val = mean_share;
U invvar = phi::funcs::rsqrt_<U>(var_share + static_cast<U>(epsilon));
// calculate layernorm_dst
CalcLayernormY<T, VecSize, U, ScaleBiasWithSameTypeX>(scale,
layernorm_bias,
dst,
layernorm_dst,
row_id,
col_id,
cols,
mean_val,
invvar);
}
template <typename T,
typename MaskType,
int VecSize,
typename U,
bool ScaleBiasWithSameTypeX = false>
void LaunchFusedLayernormResidualDropoutBiasCUDAKernel(
int grid_dim,
int block_dim,
gpuStream_t stream,
const size_t rows,
const size_t cols,
uint64_t seed,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const uint64_t increment,
const float epsilon,
const T *src,
const T *residual,
const T *bias,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
MaskType *mask,
T *dst,
T *layernorm_dst,
LayerNormParamType<T> *mean,
LayerNormParamType<T> *var) {
if (dropout_prob != 0.0f) {
FusedLayernormResidualDropoutBias<T,
MaskType,
VecSize,
U,
ScaleBiasWithSameTypeX,
true>
<<<grid_dim, block_dim, 0, stream>>>(rows,
cols,
seed,
dropout_prob,
is_upscale_in_train,
is_test,
increment,
epsilon,
src,
residual,
bias,
scale,
layernorm_bias,
mask,
dst,
layernorm_dst,
mean,
var);
} else {
FusedLayernormResidualDropoutBias<T,
MaskType,
VecSize,
U,
ScaleBiasWithSameTypeX,
false>
<<<grid_dim, block_dim, 0, stream>>>(rows,
cols,
seed,
dropout_prob,
is_upscale_in_train,
is_test,
increment,
epsilon,
src,
residual,
bias,
scale,
layernorm_bias,
mask,
dst,
layernorm_dst,
mean,
var);
}
}
/**
* @brief layernorm(residual + dropout(src + bias));
* @param
* rows: batch_size * seq_len
* cols: feature_size or hidden_size
* src: [rows, cols], inputs
* bias: [cols], linear bias, can be null
* residual:[rows, cols]
* mask: [rows, cols], dropout result
* dst: [rows, cols], residual + dropout(src+bias)
* layernorm_dst: [rows, cols], layernorm result
* layernorm_bias: [cols], layernorm bias, can be null
* scale: [cols]: layernorm scale, can be null
*/
template <typename T,
typename MaskType,
int VecSize,
typename U,
bool ScaleBiasWithSameTypeX = false>
__global__ void FusedLayernormResidualDropoutBiasInfer(
const size_t rows,
const size_t cols,
uint64_t seed,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const uint64_t increment,
const float epsilon,
const T *src,
const T *residual,
const T *bias,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
MaskType *mask,
T *dst,
T *layernorm_dst) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
T factor =
phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
__shared__ U mean_share;
__shared__ U var_share;
__shared__ U shared_mean[32];
__shared__ U shared_var[32];
phi::funcs::ReluFunctor<T> relu;
U mean_val = 0;
U var_val = 0;
for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T,
MaskType,
VecSize,
true,
false,
phi::funcs::ReluFunctor<T>>(row_id,
i,
cols,
&state,
dropout_prob,
factor,
src,
residual,
bias,
dst,
mask,
is_test,
&mean_val,
&var_val,
relu);
}
mean_val = phi::funcs::BlockReduceSum<U>(mean_val, shared_mean);
var_val = phi::funcs::BlockReduceSum<U>(var_val, shared_var);
if (threadIdx.x == 0) {
auto scale = static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
static_cast<float>(1.) / static_cast<float>(cols));
auto tmp = mean_val * static_cast<U>(scale);
mean_share = static_cast<U>(tmp);
var_share = static_cast<U>(var_val * static_cast<U>(scale) -
mean_share * mean_share);
var_share = var_share > U(0) ? var_share : U(0);
}
__syncthreads();
mean_val = mean_share;
U invvar = phi::funcs::rsqrt_<U>(var_share + static_cast<U>(epsilon));
// calculate layernorm_dst
CalcLayernormY<T, VecSize, U, ScaleBiasWithSameTypeX>(scale,
layernorm_bias,
dst,
layernorm_dst,
row_id,
col_id,
cols,
mean_val,
invvar);
}
template <typename T,
typename MaskType,
int VecSize,
typename U,
bool ScaleBiasWithSameTypeX = false>
struct FusedLayernormResidualDropoutBiasFunctor {
void operator()(
const size_t rows,
const size_t cols,
uint64_t seed,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const uint64_t increment,
const float epsilon,
const T *src,
const T *residual,
const T *bias,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
MaskType *mask,
T *dst,
T *layernorm_dst,
LayerNormParamType<T> *mean,
LayerNormParamType<T> *var,
cudaStream_t stream) {
int blockDim = phi::funcs::GetDesiredBlockDim(cols / VecSize);
if (mean != nullptr && var != nullptr) {
LaunchFusedLayernormResidualDropoutBiasCUDAKernel<T,
MaskType,
VecSize,
U,
ScaleBiasWithSameTypeX>(
rows,
blockDim,
stream,
rows,
cols,
seed,
dropout_prob,
is_upscale_in_train,
is_test,
increment,
epsilon,
src,
residual,
bias,
scale,
layernorm_bias,
mask,
dst,
layernorm_dst,
mean,
var);
} else {
FusedLayernormResidualDropoutBiasInfer<T,
MaskType,
VecSize,
U,
ScaleBiasWithSameTypeX>
<<<rows, blockDim, 0, stream>>>(rows,
cols,
seed,
dropout_prob,
is_upscale_in_train,
is_test,
increment,
epsilon,
src,
residual,
bias,
scale,
layernorm_bias,
mask,
dst,
layernorm_dst);
}
}
};
template struct FusedLayernormResidualDropoutBiasFunctor<phi::dtype::float16,
uint8_t,
8,
float,
false>;
/*
* @brief layernorm(residual + dropout(x));
* Conditions:
* (1) The number of cols is 768/1024/4096;
* (2) layer_norm scale and bias is not null;
* (3) linear bias is null;
* @param
* rows: batch_size * seq_len
* cols: 1024
* x_: [rows, cols], inputs
* residual_:[rows, cols]
* bias_: [cols], linear bias, can be null
* gamma_: [cols]: layernorm scale, not null
* beta_: [cols], layernorm bias, not null
* mask_out_: [rows, cols], dropout result
* residual_out_: [rows, cols], residual + dropout(src)
* y_: [rows, cols], layernorm result
* mean_out_: [rows]: layernorm means
* var_out_: [rows]: layernorm vars
*/
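// With the default template arguments (VecSize = 8, WARPS_M = 4, WARPS_N = 1,
// ELTS_PER_ROW = 1024): each CTA has 4 warps and processes 4 rows, one warp
// per row; ELTS_PER_ROW_PER_CTA = 32 * 8 = 256, so each thread performs
// LDGS = 1024 / 256 = 4 vectorized loads of 8 elements to cover its row.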
template <bool HasDropout,
typename T,
typename U,
typename ScaleT = U,
typename MaskType = uint8_t,
int VecSize = 8,
int WARPS_M = 4,
int WARPS_N = 1,
int BYTES_PER_LDG = 16,
int ELTS_PER_ROW = 1024,
int THREADS_PER_WARP = 32,
int THREADS_PER_ROW = WARPS_N *THREADS_PER_WARP,
int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW,
int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA,
typename InType = T,
typename OutType = T>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
int rows,
int cols,
uint64_t seed,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const uint64_t increment,
const float epsilon,
const InType *__restrict__ x_ptr,
const T *__restrict__ residual_ptr,
const T *__restrict__ bias_ptr,
const ScaleT *__restrict__ gamma_ptr,
const ScaleT *__restrict__ beta_ptr,
MaskType *__restrict__ mask_out_ptr,
U *__restrict__ mean_out_ptr,
U *__restrict__ var_out_ptr,
T *__restrict__ residual_out_ptr,
OutType *__restrict__ y_ptr,
const float quant_last_in_scale = 1.0,
const float *__restrict__ quant_out_scale_ptr = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
__shared__ U smem[WARPS_M * WARPS_N];
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
using Vec_in_type = phi::AlignedVector<InType, VecSize>;
using Vec_out_type = phi::AlignedVector<OutType, VecSize>;
using Vec_float = phi::AlignedVector<float, VecSize>;
using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP; // 0, 1, ..., 31
const int warp = tidx / THREADS_PER_WARP; // 0, 1, 2, 3
const int warp_n = warp % WARPS_N; // 0
const int warp_m = warp / WARPS_N; // 0, 1, 2, 3
const int c = warp_n * THREADS_PER_WARP + lane; // lane
const int r = bidx * ROWS_PER_CTA + warp_m; // row id
int idx = r * ELTS_PER_ROW + c;
curandStatePhilox4_32_10_t state;
if (HasDropout) {
curand_init(seed, idx, increment, &state);
}
T factor =
phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
// bias
Vec bias[LDGS];
if (bias_ptr != nullptr) {
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<T, VecSize>(bias_ptr + col * VecSize, &bias[it]);
col += THREADS_PER_ROW;
}
}
Vec_scale gamma[LDGS];
Vec_scale beta[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
phi::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
col += THREADS_PER_ROW;
}
constexpr U rn = 1.f / U(ELTS_PER_ROW);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
Vec_in_type x_input[LDGS];
Vec residual[LDGS];
Vec_float dequant_out_scale[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<T, VecSize>(residual_ptr + row * ELTS_PER_ROW + col * VecSize,
&residual[it]);
phi::Load<InType, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize,
&x_input[it]);
if (quant_out_scale_ptr != nullptr) {
phi::Load<float, VecSize>(quant_out_scale_ptr + col * VecSize,
&dequant_out_scale[it]);
}
col += THREADS_PER_ROW;
}
MaskStoreT mask_vec[LDGS];
if (!is_test && HasDropout) {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
float rand[VecSize];
RandVec<VecSize>(&state, rand);
#pragma unroll
        for (int jt = 0; jt < VecSize; jt++) {
          mask_vec[it][jt] = static_cast<MaskType>(rand[jt] >= dropout_prob);
}
}
} else {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
mask_vec[it][jt] = static_cast<MaskType>(1);
}
}
}
    // row data promoted to U: LDGS * VecSize elements (4 * 8 by default)
U xf[LDGS * VecSize];
if (bias_ptr != nullptr) {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
if (std::is_same<InType, int32_t>::value) {
T tmp = (static_cast<T>(static_cast<float>(x_input[it][jt]) *
dequant_out_scale[it][jt]) +
bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
x[it][jt] = tmp;
xf[it * VecSize + jt] = U(tmp);
} else {
x[it][jt] = (static_cast<T>(x_input[it][jt]) + bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
xf[it * VecSize + jt] = U(x[it][jt]);
}
}
}
} else {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
if (std::is_same<InType, int32_t>::value) {
// for int32 input, we need to dequantize.
T tmp = static_cast<T>(static_cast<float>(x_input[it][jt]) *
dequant_out_scale[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
x[it][jt] = tmp;
} else {
x[it][jt] = static_cast<T>(x_input[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
}
xf[it * VecSize + jt] = U(x[it][jt]);
}
}
}
// store dropout_residual_out and mask_out
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<T, VecSize>(
x[it], residual_out_ptr + row * ELTS_PER_ROW + col * VecSize);
col += THREADS_PER_ROW;
}
if (!is_test && HasDropout) {
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<MaskType, VecSize>(
mask_vec[it], mask_out_ptr + row * ELTS_PER_ROW + col * VecSize);
col += THREADS_PER_ROW;
}
}
U mu_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
mu_local += xf[it * VecSize + jt];
}
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
}
if (WARPS_N > 1) {
if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = mu_local;
}
__syncthreads();
if (tidx % THREADS_PER_ROW == 0) {
mu_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
mu_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = mu_local;
}
__syncthreads();
mu_local = smem[warp_m];
}
mu_local *= rn;
if (lane == 0) {
mean_out_ptr[row] = mu_local;
}
U var_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
U diff = xf[it * VecSize + jt] - mu_local;
var_local += diff * diff;
}
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
if (WARPS_N > 1) {
if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = var_local;
}
__syncthreads();
if (tidx % THREADS_PER_ROW == 0) {
var_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
var_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = var_local;
}
__syncthreads();
var_local = smem[warp_m];
}
U rsigma = rsqrtf(var_local * rn + epsilon);
if (lane == 0) {
// Note: the stored var is different for paddle(ln) and apex (fast ln).
// var_out_ptr[row] = rsigma;
var_out_ptr[row] = var_local * rn;
}
Vec_out_type x_output[LDGS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// use fp16 to compute
// ScaleT tmp = static_cast<ScaleT>(rsigma * (xf[it * VecSize + jt] -
// mu_local));
// x[it][jt] = gamma[it][jt] * tmp + beta[it][jt];
// cast to fp32 to compute
U tmp = rsigma * (static_cast<U>(xf[it * VecSize + jt]) - mu_local);
x[it][jt] = static_cast<T>(static_cast<U>(gamma[it][jt]) * tmp +
static_cast<U>(beta[it][jt]));
if (std::is_same<OutType, int8_t>::value)
x_output[it][jt] = phi::funcs::quant_helper(x[it][jt],
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
}
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
if (std::is_same<OutType, int8_t>::value) {
phi::Store<OutType, VecSize>(
x_output[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
} else {
phi::Store<T, VecSize>(
x[it],
reinterpret_cast<T *>(y_ptr) + row * ELTS_PER_ROW + col * VecSize);
}
col += THREADS_PER_ROW;
}
}
}
/**
* @brief layernorm(residual + dropout(src + bias));
* @param
* rows: batch_size * seq_len
* cols: feature_size or hidden_size
* src: [rows, cols], inputs
* bias: [cols], linear bias, can be null
* residual:[rows, cols]
* mask: [rows, cols], dropout result, can be null if is_test = true
* dst: [rows, cols], residual + dropout(src+bias)
* layernorm_dst: [rows, cols], layernorm result
* layernorm_bias: [cols], layernorm bias, can be null
* scale: [cols]: layernorm scale, can be null
* means: [rows]: layernorm means
* vars: [rows]: layernorm vars
*/
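/*
 * Illustrative call (a sketch only; the variable names below are hypothetical
 * placeholders, not symbols defined in this header):
 *
 *   // fp16 data, uint8 dropout mask, fp32 layer_norm statistics
 *   LaunchLayernormResidualDropoutBias<phi::dtype::float16, uint8_t, float>(
 *       rows, cols, increment, seed,
 *       dropout_prob, epsilon, is_upscale_in_train, is_test,
 *       src, residual, bias, scale, layernorm_bias,
 *       mask_data, dst, layernorm_dst, mean, var, dev_ctx);
 */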
template <typename T,
typename MaskType,
typename U,
bool ScaleBiasWithSameTypeX = false,
typename InType = T,
typename OutType = T>
void LaunchLayernormResidualDropoutBias(
const uint32_t rows,
const uint32_t cols,
const int increment,
uint64_t seed,
const float dropout_prob,
const float epsilon,
const bool is_upscale_in_train,
const bool is_test,
const InType *src,
const T *residual,
const T *bias,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
MaskType *mask_data,
T *dst,
OutType *layernorm_dst,
LayerNormParamType<T> *mean,
LayerNormParamType<T> *var,
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
// dropout_prob == 1.0f
  // NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
auto cuda_place = ctx.GetPlace();
phi::memory_utils::Copy(cuda_place,
dst,
cuda_place,
residual,
rows * cols * sizeof(T),
ctx.stream());
if (mask_data != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
}
// call layernorm forward
switch (phi::funcs::GetDesiredBlockDim(cols)) {
FIXED_BLOCK_DIM_CASE(
phi::funcs::LayerNormForward<T, U, kBlockDim, ScaleBiasWithSameTypeX>
<<<rows, kBlockDim, 0, ctx.stream()>>>(
dst,
scale,
layernorm_bias,
reinterpret_cast<T *>(layernorm_dst),
mean,
var,
epsilon,
cols));
default:
PADDLE_THROW(errors::InvalidArgument(
"Product from begin_norm_axis to end must be larger than 1"));
break;
}
return;
}
#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \
case (cols): { \
constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024); \
constexpr int WARPS_M = 4 / WARPS_N; \
const int THREADS_PER_WARP = 32; \
const int BYTES_PER_LDG = 16; \
const int VecSize = BYTES_PER_LDG / sizeof(T); \
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \
const int ROWS_PER_CTA = WARPS_M; \
const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP; \
const int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize; \
const int LDGS = cols / ELTS_PER_ROW_PER_CTA; \
const int grid = \
static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA))); \
if (dropout_prob != 0.0f) { \
fused_fast_ln_fwd_kernel< \
true, \
T, \
U, \
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, \
uint8_t, \
VecSize, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
cols, \
THREADS_PER_WARP, \
THREADS_PER_ROW, \
THREADS_PER_CTA, \
ROWS_PER_CTA, \
ELTS_PER_ROW_PER_CTA, \
LDGS, \
InType, \
OutType> \
<<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(rows, \
cols, \
seed, \
dropout_prob, \
is_upscale_in_train, \
is_test, \
increment, \
epsilon, \
src, \
residual, \
bias, \
scale, \
layernorm_bias, \
mask_data, \
mean, \
var, \
dst, \
layernorm_dst, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_next_in_scale, \
quant_round_type, \
quant_max_bound, \
quant_min_bound); \
} else { \
fused_fast_ln_fwd_kernel< \
false, \
T, \
U, \
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, \
uint8_t, \
VecSize, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
cols, \
THREADS_PER_WARP, \
THREADS_PER_ROW, \
THREADS_PER_CTA, \
ROWS_PER_CTA, \
ELTS_PER_ROW_PER_CTA, \
LDGS, \
InType, \
OutType> \
<<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(rows, \
cols, \
seed, \
dropout_prob, \
is_upscale_in_train, \
is_test, \
increment, \
epsilon, \
src, \
residual, \
bias, \
scale, \
layernorm_bias, \
mask_data, \
mean, \
var, \
dst, \
layernorm_dst, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_next_in_scale, \
quant_round_type, \
quant_max_bound, \
quant_min_bound); \
} \
} break
#define LAUNCH_FUSED_FAST_LN_KERNEL \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1280); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1536); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1792); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(2048); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(3072); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)
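  // Worked example of the constants above for fp16 and cols = 1280:
  // WARPS_N = 1280 / 1024 = 1, WARPS_M = 4, VecSize = 16 / sizeof(half) = 8,
  // THREADS_PER_ROW = 32, ELTS_PER_ROW_PER_CTA = 256, LDGS = 1280 / 256 = 5,
  // so each CTA of 128 threads handles 4 rows per iteration.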
bool can_call_fast_ln_kernel = false;
if (((cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 3072 ||
cols == 4096) &&
scale != nullptr && layernorm_bias != nullptr) {
can_call_fast_ln_kernel = true;
}
VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel;
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
if (cols % VecSize != 0) {
int blockDim = phi::funcs::GetDesiredBlockDim(cols);
LaunchFusedLayernormResidualDropoutBiasCUDAKernel<T,
uint8_t,
1,
U,
ScaleBiasWithSameTypeX>(
rows,
blockDim,
ctx.stream(),
rows,
cols,
seed,
dropout_prob,
is_upscale_in_train,
is_test,
increment,
epsilon,
reinterpret_cast<const T *>(src),
residual,
bias,
scale,
layernorm_bias,
mask_data,
dst,
reinterpret_cast<T *>(layernorm_dst),
mean,
var);
} else {
if (can_call_fast_ln_kernel) {
switch (cols) {
LAUNCH_FUSED_FAST_LN_KERNEL;
default:
PADDLE_THROW(errors::InvalidArgument(
"Only when column is equal to 768/1024/4096 is supported for "
"now"));
break;
}
} else {
int blockDim = phi::funcs::GetDesiredBlockDim(cols / VecSize);
LaunchFusedLayernormResidualDropoutBiasCUDAKernel<T,
uint8_t,
VecSize,
U,
ScaleBiasWithSameTypeX>(
rows,
blockDim,
ctx.stream(),
rows,
cols,
seed,
dropout_prob,
is_upscale_in_train,
is_test,
increment,
epsilon,
reinterpret_cast<const T *>(src),
residual,
bias,
scale,
layernorm_bias,
mask_data,
dst,
reinterpret_cast<T *>(layernorm_dst),
mean,
var);
}
}
}
template <typename T,
typename U,
typename MaskType,
bool ScaleBiasWithSameTypeX = false>
void LaunchLayernormResidualDropoutGrad(
const phi::GPUContext &dev_ctx,
const uint32_t rows,
const uint32_t cols,
const float epsilon,
const float dropout_prob,
const bool is_upscale_in_train,
const T *d_out,
const T *layernorm_src,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormParamType<T> *mean,
const LayerNormParamType<T> *var,
const MaskType *mask_data,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_layernorm_bias,
T *d_residual,
T *d_dropout_src) {
const T zero = static_cast<T>(0.0f);
auto factor = dropout_prob == static_cast<float>(1.0f)
? zero
: static_cast<T>(1.0f / (1.0f - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
phi::funcs::ln_bwd_fast_kernel_driver<
T,
U,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
MaskType>(dev_ctx,
rows,
cols,
epsilon,
layernorm_src,
scale,
mean,
var,
d_out,
d_residual,
d_scale,
d_layernorm_bias,
mask_data,
factor,
d_dropout_src);
}
} // namespace fusion
} // namespace phi
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CUDA)
#include <cuda.h>
#endif
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_common.h"
namespace phi {
namespace fusion {
/**
* @brief The fused function called by every thread
* VecSize can be 1, 2, 4 or 8
*/
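/*
 * Per element this computes, in effect (a sketch that ignores the int8
 * quant/dequant path):
 *   dst = act(dequant(src) + bias) * mask * factor + residual
 * The caller passes factor = 1 / (1 - dropout_prob) for upscale_in_train
 * training (e.g. dropout_prob = 0.1 gives factor ~= 1.111); see GetFactor for
 * the other modes.
 */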
template <typename T,
typename MaskType,
int VecSize,
bool ComputeLayerNorm,
bool Activation,
typename Functor,
typename InType = T,
typename OutType = T,
bool HasDropout = true>
__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
const int row_id,
const int col_id,
const int cols,
curandStatePhilox4_32_10_t *state,
const float dropout_prob,
const T factor,
const InType *__restrict__ src,
const T *__restrict__ residual,
const T *__restrict__ bias,
OutType *dst,
MaskType *mask,
const bool is_test,
typename phi::dtype::MPTypeTrait<T>::Type *mean_val,
typename phi::dtype::MPTypeTrait<T>::Type *var_val,
Functor act_func,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
using LoadT = phi::AlignedVector<T, VecSize>;
using LoadInType = phi::AlignedVector<InType, VecSize>;
using LoadFloat = phi::AlignedVector<float, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using StoreOutType = phi::AlignedVector<OutType, VecSize>;
using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
using U = typename phi::dtype::MPTypeTrait<T>::Type;
LoadInType src_vec;
LoadT residual_vec;
LoadT bias_vec;
LoadFloat quant_out_scale_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
bias_vec[ii] = static_cast<T>(0);
residual_vec[ii] = static_cast<T>(0);
}
// vectorize load data from global
phi::Load<InType, VecSize>(&src[row_id * cols + col_id], &src_vec);
phi::Load<float, VecSize>(&dequant_out_scale_data[col_id],
&quant_out_scale_vec);
if (residual) {
phi::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec);
}
if (bias) {
phi::Load<T, VecSize>(&bias[col_id], &bias_vec);
}
MaskStoreT mask_vec;
if (!is_test && HasDropout) {
float rand[VecSize];
RandVec<VecSize>(state, rand);
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
}
} else {
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
mask_vec[ii] = static_cast<MaskType>(1);
}
}
StoreT dest_vec;
StoreOutType dest_vec_out_type;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
T tmp;
if (std::is_same<InType, int32_t>::value) {
T tmp0 = static_cast<T>(static_cast<float>(src_vec[ii]) *
quant_out_scale_vec[ii]);
tmp = tmp0 + bias_vec[ii];
} else {
tmp = static_cast<T>(src_vec[ii]) + bias_vec[ii];
}
if (Activation) {
tmp = act_func(tmp);
}
if (HasDropout) {
dest_vec[ii] =
tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
} else {
dest_vec[ii] = tmp * factor + residual_vec[ii];
}
if (ComputeLayerNorm) {
U tmp = static_cast<U>(dest_vec[ii]);
*mean_val += tmp;
*var_val += (tmp * tmp);
}
if (std::is_same<OutType, int8_t>::value) {
dest_vec_out_type[ii] = phi::funcs::quant_helper(dest_vec[ii],
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
}
// store result to global
if (std::is_same<OutType, int8_t>::value) {
phi::Store<OutType, VecSize>(dest_vec_out_type,
&dst[row_id * cols + col_id]);
} else {
phi::Store<T, VecSize>(dest_vec,
reinterpret_cast<T *>(&dst[row_id * cols + col_id]));
}
if (!is_test && HasDropout) {
phi::Store<MaskType, VecSize>(mask_vec, &mask[row_id * cols + col_id]);
}
}
/**
* blocks(128 * 8)
* 1. calculate the dx and reduce total rows to 128 rows
* 2. save 128*8 temporary sum in 8*128 shared memory
* 3. reduce the sum of 128 rows data by 8*VecSize warps
*/
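/*
 * Launch-shape sketch (see LaunchResidualDropoutBiasGrad below): the block is
 * (BlockSizeX = 8, BlockSizeY = 128), so each block owns 8 * VecSize columns
 * and its 128 y-threads stride over all rows; e.g. with VecSize = 8 that is
 * 64 columns per block.
 */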
template <typename T,
typename MaskType,
int BlockSizeX,
int BlockSizeY,
int VecSize,
bool HasDropout>
__global__ void FusedResidualDropoutBiasGrad(const T *dout,
const MaskType *mask,
const T factor,
const int64_t rows,
const int64_t cols,
T *dx,
T *dbias) {
int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
T tmp_sum[VecSize] = {static_cast<T>(0)};
// calculate the dx and temporary sum
const bool not_need_dx = (dx == nullptr) || (dx == dout && !HasDropout &&
factor == static_cast<T>(1.0));
if (col_id * VecSize < cols) {
for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
int index = row_id * cols + col_id * VecSize;
LoadT out_vec;
MaskLoadT mask_vec;
StoreT dx_vec;
phi::Load<T, VecSize>(&dout[index], &out_vec);
if (HasDropout) {
phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
}
if (not_need_dx) {
#pragma unroll
for (int i = 0; i < VecSize; i++) {
tmp_sum[i] += out_vec[i];
}
} else {
#pragma unroll
for (int i = 0; i < VecSize; i++) {
if (HasDropout) {
dx_vec[i] = out_vec[i] * static_cast<T>(mask_vec[i]) * factor;
} else {
dx_vec[i] = out_vec[i] * factor;
}
tmp_sum[i] += out_vec[i];
}
phi::Store<T, VecSize>(dx_vec, &dx[index]);
}
}
}
CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}
/*
* @brief calculate the grad of no bias
*/
template <typename T, typename MaskType, int VecSize>
__global__ void FusedResidualDropoutGrad(const T *dout,
const MaskType *mask,
const T factor,
const int64_t size,
T *dx) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
LoadT dout_vec;
MaskLoadT mask_vec;
phi::Load<T, VecSize>(&dout[i], &dout_vec);
phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);
StoreT dx_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
}
phi::Store<T, VecSize>(dx_vec, &dx[i]);
}
}
/**
* @brief dst = residual + dropout(src + bias);
* the src, residual, mask and dst shape is (rows, cols)
* the bias shape is (1, cols)
* is_test: only used in inference
* mask: can be null if is_test=true
*/
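// Thread-mapping sketch: blockIdx.y indexes rows and each x-thread covers
// VecSize consecutive columns, so a thread calls
// FusedResidualDropoutBiasOneThread once per (row, column-chunk) it owns.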
template <typename T,
typename MaskType,
int VecSize,
typename InType = T,
typename OutType = T,
bool HasDropout = true>
__global__ void FusedResidualDropoutBias(
const size_t rows,
const size_t cols,
uint64_t seed,
const float dropout_prob,
const bool is_upscale_in_train,
const InType *__restrict__ src,
const T *__restrict__ residual,
const T *__restrict__ bias,
MaskType *mask,
OutType *dst,
uint64_t increment,
const bool is_test,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
if (HasDropout) {
curand_init(seed, idx, increment, &state);
}
T factor;
if (HasDropout) {
factor =
phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
} else {
factor = static_cast<T>(1);
}
phi::funcs::ReluFunctor<T> relu;
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
i += blockDim.x * gridDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T,
MaskType,
VecSize,
false,
false,
phi::funcs::ReluFunctor<T>,
InType,
OutType,
HasDropout>(r,
i,
cols,
&state,
dropout_prob,
factor,
src,
residual,
bias,
dst,
mask,
is_test,
nullptr,
nullptr,
relu,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
}
}
}
/**
* @brief dst = residual + dropout(src + bias);
*/
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
void LaunchResidualDropoutBias(const uint32_t rows,
const uint32_t cols,
const int increment,
uint64_t seed,
const float dropout_prob,
const bool is_test,
bool is_upscale_in_train,
const InType *src,
const T *residual,
const T *bias,
MaskType *mask_data,
OutType *dst,
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
// NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0
if (residual == dst) return;
if (residual) {
phi::memory_utils::Copy(ctx.GetPlace(),
dst,
ctx.GetPlace(),
residual,
rows * cols * sizeof(T),
ctx.stream());
} else {
SetZero<T>(ctx, dst, rows * cols);
}
if (!is_test) {
SetZero<MaskType>(ctx, mask_data, rows * cols);
}
return;
}
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
#define PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(__has_dropout) \
do { \
if (cols % VecSize == 0) { \
FusedResidualDropoutBias<T, \
uint8_t, \
VecSize, \
InType, \
OutType, \
__has_dropout> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
ctx.stream()>>>(rows, \
cols, \
seed, \
dropout_prob, \
is_upscale_in_train, \
src, \
residual, \
bias, \
mask_data, \
dst, \
increment, \
is_test, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_next_in_scale); \
} else { \
FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType, __has_dropout> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
ctx.stream()>>>(rows, \
cols, \
seed, \
dropout_prob, \
is_upscale_in_train, \
src, \
residual, \
bias, \
mask_data, \
dst, \
increment, \
is_test, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_next_in_scale); \
} \
} while (0)
if (dropout_prob != 0.0f) {
PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(true);
} else {
PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(false);
}
#undef PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL
}
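/*
 * Illustrative call (a sketch only; the variable names are hypothetical
 * placeholders):
 *
 *   LaunchResidualDropoutBias<phi::dtype::float16, uint8_t>(
 *       rows, cols, increment, seed,
 *       dropout_prob, is_test, is_upscale_in_train,
 *       src, residual, bias, mask_data, dst, dev_ctx);
 */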
/**
 * @brief launch FusedResidualDropoutBiasGrad, or FusedResidualDropoutGrad when dbias is nullptr
*/
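/*
 * Gradient sketch: dx = dout * mask * factor and dbias = column-wise sum of
 * dout over the rows, with factor = 1 / (1 - dropout_prob) for
 * upscale_in_train (e.g. 1.25 for dropout_prob = 0.2) and factor = 1 otherwise.
 */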
template <typename T, typename MaskType>
void LaunchResidualDropoutBiasGrad(const T *dout,
const MaskType *mask,
const float dropout_prob,
const bool is_upscale_in_train,
const uint32_t rows,
const uint32_t cols,
T *dx,
T *dbias,
const phi::GPUContext &ctx) {
const T zero = static_cast<T>(0.0f);
auto factor = dropout_prob == static_cast<float>(1.0f)
? zero
: static_cast<T>(1.0f / (1.0f - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
#define PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(__has_dropout) \
do { \
if (dbias != nullptr) { \
const auto threads = 8; \
auto blocks = std::max(static_cast<uint32_t>(1), \
(cols / real_vec_size + threads - 1) / threads); \
dim3 block_dim(threads, 128, 1); \
dim3 grid_dim(blocks, 1, 1); \
if (cols % VecSize == 0) { \
FusedResidualDropoutBiasGrad<T, \
MaskType, \
8, \
128, \
VecSize, \
__has_dropout> \
<<<grid_dim, block_dim, 0, ctx.stream()>>>( \
dout, mask, factor, rows, cols, dx, dbias); \
} else { \
FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, 1, __has_dropout> \
<<<grid_dim, block_dim, 0, ctx.stream()>>>( \
dout, mask, factor, rows, cols, dx, dbias); \
} \
} else { \
if (dropout_prob == 0.0f) { \
if (dx == nullptr || dx == dout) { \
return; \
} \
phi::memory_utils::Copy(ctx.GetPlace(), \
dx, \
ctx.GetPlace(), \
dout, \
rows *cols * sizeof(T), \
ctx.stream()); \
} else { \
const uint64_t n = rows * cols; \
phi::backends::gpu::GpuLaunchConfig config = \
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, n / real_vec_size); \
if (n % VecSize == 0) { \
FusedResidualDropoutGrad<T, MaskType, VecSize> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
ctx.stream()>>>(dout, mask, factor, n, dx); \
} else { \
FusedResidualDropoutGrad<T, MaskType, 1> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
ctx.stream()>>>(dout, mask, factor, n, dx); \
} \
} \
} \
} while (0)
if (dropout_prob != 0.0f) {
PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(true);
} else {
PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(false);
}
#undef PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL
}
} // namespace fusion
} // namespace phi
...@@ -23,6 +23,8 @@
#include <hiprand_kernel.h>
#endif
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_HIP
...@@ -89,6 +91,174 @@ __device__ __forceinline__ void warp_reduce(T* sum) {
}
}
#if defined(PADDLE_WITH_CUDA)
#define FINAL_MASK 0xffffffff
#define DIV_UP(x, y) (((x) + (y)-1) / (y))
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
template <typename T>
__inline__ __device__ T warpReduceMax(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
return val;
}
inline int ElementsCeil(int seq_len) {
int elements = 1;
while (elements * 32 < seq_len) elements *= 2;
return elements;
}
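// e.g. ElementsCeil(32) == 1, ElementsCeil(33) == 2, ElementsCeil(128) == 4,
// ElementsCeil(4096) == 128: the smallest power of two such that
// 32 lanes * elements covers seq_len.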
template <typename T, int VEC_SIZE, int ELEMENTS_PER_THREADS>
__global__ void FusedSoftmaxMaskVecKernel(T* dst,
const T* src,
const T* mask,
int seq_len) {
constexpr int block_size = 128;
constexpr int warp_size = 32;
constexpr int warps_per_block = block_size / warp_size;
// blockDim/threadIdx = (warp_size, warps_per_block)
// gridDim/blockIdx = (DIV_UP(seq_len, warps_per_block), batch_size, head_num)
  // every block processes warps_per_block (4) consecutive rows
  // seq_id = blockIdx.x * warps_per_block + warp_id; e.g. for seq_len = 128 the
  // last row is 127 = 31 * 4 + 3
int seq_id = blockIdx.x * warps_per_block + threadIdx.y;
if (seq_id >= seq_len) return;
// ((bid*head_num + hid)*seq_len + seq_id) * seq_len
int offset =
((blockIdx.y * gridDim.z + blockIdx.z) * seq_len + seq_id) * seq_len;
// (bid * seq_len + seq_id) * seq_len
int mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len;
src += offset;
dst += offset;
mask += mask_offset;
static_assert(ELEMENTS_PER_THREADS % VEC_SIZE == 0, "");
constexpr int VEC_NUMS = ELEMENTS_PER_THREADS / VEC_SIZE;
using VecT = phi::AlignedVector<T, VEC_SIZE>;
VecT elements[VEC_NUMS];
VecT tmp_mask;
float max_val = -std::numeric_limits<float>::infinity();
for (int i = 0; (i * warp_size + threadIdx.x) * VEC_SIZE < seq_len; ++i) {
phi::Load(src + (i * warp_size + threadIdx.x) * VEC_SIZE, &elements[i]);
phi::Load(mask + (i * warp_size + threadIdx.x) * VEC_SIZE, &tmp_mask);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
// TODO(wangxi): vec add
elements[i][j] += tmp_mask[j];
max_val = max(max_val, static_cast<float>(elements[i][j]));
}
}
max_val = warpReduceMax(max_val);
float sum_val = 0;
for (int i = 0; (i * warp_size + threadIdx.x) * VEC_SIZE < seq_len; ++i) {
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
float tmp = __expf(static_cast<float>(elements[i][j]) - max_val);
sum_val += tmp;
elements[i][j] = static_cast<T>(tmp);
}
}
sum_val = warpReduceSum(sum_val);
float mean_val = __fdividef(1.0f, sum_val + 1e-6f);
for (int i = 0; (i * warp_size + threadIdx.x) * VEC_SIZE < seq_len; ++i) {
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
float tmp = static_cast<float>(elements[i][j]) * mean_val;
elements[i][j] = static_cast<T>(tmp);
}
phi::Store(elements[i], dst + (i * warp_size + threadIdx.x) * VEC_SIZE);
}
}
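// Numerically this is the usual max-shifted softmax over one attention row:
//   y_j = exp((x_j + m_j) - max_k(x_k + m_k)) / (sum_k exp(...) + 1e-6)
// computed by a single warp per row via the warp reductions above.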
#define SOFTMAX_MASK_KERNEL(VEC_SIZE, ELEMENTS) \
FusedSoftmaxMaskVecKernel<T, VEC_SIZE, ELEMENTS> \
<<<grid, block, 0, stream>>>(dst, src, mask, seq_len);
#define SELECT_SOFTMAX_MASK_KERNEL(ELEMENTS) \
do { \
if (seq_len % 2 == 0) { \
SOFTMAX_MASK_KERNEL(2, ELEMENTS); \
} else { \
SOFTMAX_MASK_KERNEL(1, ELEMENTS); \
} \
} while (0)
#define CASE_SOFTMAX_MASK_KERNEL(ELEMENTS) \
case ELEMENTS: { \
SELECT_SOFTMAX_MASK_KERNEL(ELEMENTS); \
break; \
}
// template <typename T, typename MaskT = T>
template <typename T>
void LaunchFusedSoftmaxMaskKernel(const T* src,
const T* mask,
T* dst,
const int batch_size,
const int head_num,
const int seq_len,
cudaStream_t stream) {
PADDLE_ENFORCE_EQ(seq_len > 0 && seq_len <= 4096,
true,
errors::InvalidArgument("seq_len must be between (0, 4096] "
"received the seq_len is %d",
seq_len));
constexpr int block_size = 128;
constexpr int warp_size = 32;
constexpr int warps_per_block = block_size / warp_size;
// put head_num to the outside for mask
dim3 block(warp_size, warps_per_block);
dim3 grid(DIV_UP(seq_len, warps_per_block), batch_size, head_num);
int elements = ElementsCeil(seq_len);
switch (elements) {
case 1: { // <=32
SOFTMAX_MASK_KERNEL(1, 1);
break;
}
case 2: { // <=64
// if (seq_len % 2 == 0) SOFTMAX_MASK_KERNEL(2, 2);
// else SOFTMAX_MASK_KERNEL(1, 2);
SELECT_SOFTMAX_MASK_KERNEL(2);
break;
}
case 4: { // <=128
// if (seq_len % 4 == 0) SOFTMAX_MASK_KERNEL(4, 4);
// else if (seq_len % 2 == 0) SOFTMAX_MASK_KERNEL(2, 4);
// else SOFTMAX_MASK_KERNEL(1, 4);
SELECT_SOFTMAX_MASK_KERNEL(4);
break;
}
CASE_SOFTMAX_MASK_KERNEL(8); // <=256
CASE_SOFTMAX_MASK_KERNEL(16); // <=512
CASE_SOFTMAX_MASK_KERNEL(32); // <=1024
CASE_SOFTMAX_MASK_KERNEL(64); // <=2048
CASE_SOFTMAX_MASK_KERNEL(128); // <=4096
default:
PADDLE_THROW(errors::InvalidArgument(
"seq_len must be between (0, 4096], received the seq_len is %d",
seq_len));
}
}
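/*
 * Illustrative launch (a sketch only; the pointers and sizes are hypothetical
 * placeholders). src/dst are laid out as [batch_size, head_num, seq_len,
 * seq_len]; mask is broadcast over heads, i.e. [batch_size, 1, seq_len,
 * seq_len]:
 *
 *   LaunchFusedSoftmaxMaskKernel<phi::dtype::float16>(
 *       qk_out, mask, softmax_out, batch_size, head_num, seq_len,
 *       dev_ctx.stream());
 */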
#endif
} // namespace fusion
} // namespace phi
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature AttentionFuseOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fused_attention",
{"X",
"LnScale",
"LnBias",
"QKVW",
"QKVBias",
"CacheKV",
"SrcMask",
"OutLinearW",
"OutLinearBias",
"Ln2Scale",
"Ln2Bias"},
{"num_heads",
"transpose_qkv_wb",
"pre_layer_norm",
"epsilon",
"attn_dropout_rate",
"is_test",
"attn_dropout_fix_seed",
"attn_dropout_seed",
"attn_dropout_implementation",
"dropout_rate",
"dropout_fix_seed",
"dropout_seed",
"dropout_implementation",
"ln_epsilon",
"add_residual",
"ring_id"},
{"LnMean", "LnVariance",
"LnOut", "QKVOut",
"QKVBiasOut", "TransposeOut2",
"QKOut", "QKTVOut",
"SoftmaxOut", "AttnDropoutMaskOut",
"AttnDropoutOut", "SrcMaskOut",
"FMHAOut", "OutLinearOut",
"DropoutMaskOut", "Ln2Mean",
"Ln2Variance", "BiasDropoutResidualOut",
"CacheKVOut", "Y"});
}
} // namespace phi
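// The three name lists above must stay in the same order as the inputs,
// attributes and outputs of the phi fused_attention kernel;
// PD_REGISTER_ARG_MAPPING_FN below registers this mapping so the legacy fluid
// op can dispatch to that kernel.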
PD_REGISTER_ARG_MAPPING_FN(fused_attention,
phi::AttentionFuseOpArgumentMapping);
...@@ -28,8 +28,12 @@ from paddle.nn.layer.common import Dropout, Linear
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.transformer import _convert_attention_mask
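# Seed Python, NumPy and Paddle RNGs so the test results are reproducible.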
seed = 42
random.seed(seed)
default_main_program().random_seed = seed
np.random.seed(seed)
paddle.seed(seed)
class TestFusedMultiTransformerOp(OpTest):
......