diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 9a14d35b59990b928094a84ed0d8ca65b64b683e..23cdc33658d1c1c0b71c34fffea5551c3c496b3d 100755
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -38,6 +38,8 @@ if(WITH_XPU)
   op_library(resnet_basic_block_op)
   op_library(resnet_unit_op)
   op_library(fused_gemm_epilogue_op)
+  op_library(fused_attention_op)
+  op_library(fused_feedforward_op)
 endif()
 
 if(WITH_GPU OR WITH_ROCM)
diff --git a/paddle/fluid/operators/fused/fused_attention_op_xpu.cc b/paddle/fluid/operators/fused/fused_attention_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6bf2e3d80335f20c4fed3fa2c300be0838d8e052
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_attention_op_xpu.cc
@@ -0,0 +1,939 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/fused/xpu_fused_common_function.h"
+#include "paddle/fluid/operators/matmul_v2_op.h"
+#include "paddle/fluid/operators/xpu_api_wrapper.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = phi::DenseTensor;
+
+template <typename T>
+class FusedAttentionOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    using XPUTypeT = typename XPUTypeTrait<T>::Type;
+
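+    // The fused forward pass below is:
+    //   [pre layer_norm] -> qkv fc -> bias add -> transpose -> scale (q) ->
+    //   qk matmul -> [mask add] -> softmax -> attn dropout -> qktv matmul ->
+    //   transpose -> out-linear fc -> bias add -> dropout -> residual add ->
+    //   [post layer_norm when pre_layer_norm == false]
+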
+    // inputs tensor
+    auto *input_x = ctx.Input<Tensor>("X");
+
+    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
+
+    // shape [3, num_head, dim_head, dim_embed]
+    auto *qkv_weight = ctx.Input<Tensor>("QKVW");
+    // shape [3, num_head, dim_head]
+    auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
+
+    // shape [batch_size, 1, 1, seq_len]
+    auto *src_mask = ctx.Input<Tensor>("SrcMask");
+
+    // shape [dim_embed, dim_embed]
+    auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
+    // shape [dim_embed]
+    auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
+
+    const Tensor *ln_scale = nullptr;
+    const Tensor *ln_bias = nullptr;
+    float epsilon = 0.0f;
+
+    if (pre_layer_norm) {
+      ln_scale = ctx.Input<Tensor>("LnScale");
+      ln_bias = ctx.Input<Tensor>("LnBias");
+      epsilon = ctx.Attr<float>("epsilon");
+    } else {
+      ln_scale = ctx.Input<Tensor>("Ln2Scale");
+      ln_bias = ctx.Input<Tensor>("Ln2Bias");
+      epsilon = ctx.Attr<float>("ln_epsilon");
+    }
+
+    // outputs tensor
+    // the qkv values, with the transpose already applied
+    // shape [3, batch_size, num_head, seq_len, dim_head]
+    auto *TransposeOut2 = ctx.Output<Tensor>("TransposeOut2");
+
+    // shape [batch_size, num_head, seq_len, seq_len]
+    auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
+    // shape [batch_size, num_head, seq_len, seq_len]
+    auto *attn_dropout_mask_out = ctx.Output<Tensor>("AttnDropoutMaskOut");
+    // shape [batch_size, num_head, seq_len, seq_len]
+    auto *attn_dropout_out = ctx.Output<Tensor>("AttnDropoutOut");
+
+    // shape [batch_size, seq_len, num_head, dim_head]
+    auto *fmha_out = ctx.Output<Tensor>("FMHAOut");
+
+    // shape [batch_size, seq_len, dim_embed]
+    auto *dropout_mask_out = ctx.Output<Tensor>("DropoutMaskOut");
+
+    // final output
+    // shape [batch_size, seq_len, dim_embed]
+    auto *out = ctx.Output<Tensor>("Y");
+
+    // The tensors below do not need to be returned, but the new dygraph
+    // mode requires them.
+    auto *QKOut = ctx.Output<Tensor>("QKOut");
+    QKOut->mutable_data<T>(ctx.GetPlace());
+    auto *QKTVOut = ctx.Output<Tensor>("QKTVOut");
+    QKTVOut->mutable_data<T>(ctx.GetPlace());
+    auto *OutLinearOut = ctx.Output<Tensor>("OutLinearOut");
+    OutLinearOut->mutable_data<T>(ctx.GetPlace());
+    auto *QKVBiasOut = ctx.Output<Tensor>("QKVBiasOut");
+    QKVBiasOut->mutable_data<T>(ctx.GetPlace());
+    auto *SrcMaskOut = ctx.Output<Tensor>("SrcMaskOut");
+    SrcMaskOut->mutable_data<T>(ctx.GetPlace());
+    auto *qkv_out = ctx.Output<Tensor>("QKVOut");
+    qkv_out->mutable_data<T>(ctx.GetPlace());
+
+    Tensor *bias_dropout_residual_out = nullptr;
+    Tensor *ln_mean = nullptr;
+    Tensor *ln_var = nullptr;
+    Tensor *ln_out = nullptr;
+
+    if (pre_layer_norm) {
+      ln_mean = ctx.Output<Tensor>("LnMean");
+      ln_var = ctx.Output<Tensor>("LnVariance");
+      ln_out = ctx.Output<Tensor>("LnOut");
+    } else {
+      ln_mean = ctx.Output<Tensor>("Ln2Mean");
+      ln_var = ctx.Output<Tensor>("Ln2Variance");
+      bias_dropout_residual_out = ctx.Output<Tensor>("BiasDropoutResidualOut");
+    }
+
+    // dropout info
+    float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
+
+    bool is_test_1 = ctx.Attr<bool>("is_test");
+
+    auto &dropout_implementation_1 =
+        ctx.Attr<std::string>("attn_dropout_implementation");
+
+    bool is_upscale_in_train_1 =
+        (dropout_implementation_1 == "upscale_in_train");
+    auto *seed_1 =
+        ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
+
+    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
+
+    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
+
+    XPUDropoutParam attn_dropout_param;
+    attn_dropout_param.initXPUDropoutParam(attn_dropout_rate,
+                                           is_upscale_in_train_1,
+                                           is_test_1,
+                                           is_fix_seed_1,
+                                           seed_1,
+                                           seed_val_1);
+
+    XPUDropoutParam dropout_param(ctx, 0);
+
+    // compute the dimensions first
+    const auto input_x_dims = input_x->dims();
+    const auto qkv_w_dims = qkv_weight->dims();
+
+    int batch_size = input_x_dims[0];
+    int seq_len = input_x_dims[1];
+    int embed_dims = input_x_dims[2];
+    int num_heads = qkv_w_dims[1];
+    int head_dims = qkv_w_dims[2];
+
+    // input pointers
+    const XPUTypeT *input_x_ptr =
+        reinterpret_cast<const XPUTypeT *>(input_x->data<T>());
+
+    const XPUTypeT *qkv_weight_ptr =
+        reinterpret_cast<const XPUTypeT *>(qkv_weight->data<T>());
+    const XPUTypeT *qkv_bias_ptr =
+        reinterpret_cast<const XPUTypeT *>(qkv_bias->data<T>());
+    const XPUTypeT *src_mask_ptr =
+        (src_mask == nullptr)
+            ? (nullptr)
+            : (reinterpret_cast<const XPUTypeT *>(src_mask->data<T>()));
+
+    const XPUTypeT *out_linear_weight_ptr =
+        reinterpret_cast<const XPUTypeT *>(out_linear_weight->data<T>());
+
+    const XPUTypeT *out_linear_bias_ptr =
+        reinterpret_cast<const XPUTypeT *>(out_linear_bias->data<T>());
+
+    const float *ln_scale_ptr =
+        (ln_scale == nullptr) ? (nullptr) : ln_scale->data<float>();
+
+    const float *ln_bias_ptr =
+        (ln_bias == nullptr) ? (nullptr) : ln_bias->data<float>();
+
+    // output pointers
+    XPUTypeT *qkv_transpose_out_ptr = reinterpret_cast<XPUTypeT *>(
+        TransposeOut2->mutable_data<T>(ctx.GetPlace()));
+
+    XPUTypeT *softmax_out_ptr = reinterpret_cast<XPUTypeT *>(
+        softmax_out->mutable_data<T>(ctx.GetPlace()));
+
+    XPUTypeT *attn_dropout_mask_out_ptr = reinterpret_cast<XPUTypeT *>(
+        attn_dropout_mask_out->mutable_data<T>(ctx.GetPlace()));
+
+    XPUTypeT *attn_dropout_out_ptr = reinterpret_cast<XPUTypeT *>(
+        attn_dropout_out->mutable_data<T>(ctx.GetPlace()));
+
+    XPUTypeT *fmha_out_ptr = reinterpret_cast<XPUTypeT *>(
+        fmha_out->mutable_data<T>(ctx.GetPlace()));
+
+    XPUTypeT *dropout_mask_out_ptr = reinterpret_cast<XPUTypeT *>(
+        dropout_mask_out->mutable_data<T>(ctx.GetPlace()));
+
+    XPUTypeT *out_ptr =
+        reinterpret_cast<XPUTypeT *>(out->mutable_data<T>(ctx.GetPlace()));
+
+    XPUTypeT *bias_dropout_residual_out_ptr =
+        (bias_dropout_residual_out == nullptr)
+            ? (nullptr)
+            : (reinterpret_cast<XPUTypeT *>(
+                  bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace())));
+
+    float *ln_mean_ptr =
+        (ln_mean == nullptr) ? (nullptr)
+                             : ln_mean->mutable_data<float>(ctx.GetPlace());
+
+    float *ln_var_ptr =
+        (ln_var == nullptr) ? (nullptr)
+                            : ln_var->mutable_data<float>(ctx.GetPlace());
+
+    XPUTypeT *ln_out_ptr =
+        (ln_out == nullptr)
+            ? (nullptr)
+            : (reinterpret_cast<XPUTypeT *>(
+                  ln_out->mutable_data<T>(ctx.GetPlace())));
+
+    auto &dev_ctx = ctx.template device_context<phi::XPUContext>();
+
+    xpu::Context *xpu_ctx = dev_ctx.x_context();
+
+    xpu::ctx_guard RAII_GUARD(xpu_ctx);
+
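+    // Workspace planning: the four temporaries below (pre-transpose qkv, qk,
+    // qkv, linear_out) reuse a single scratch buffer. The largest candidate
+    // is always allocated in GM as a fallback; the loop afterwards promotes
+    // to L3 every temporary whose size fits into the largest candidate that
+    // L3 can hold.
+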
+    int l3_total_size = xpu_ctx->_l3_mgr.get_size();
+
+    XPUTypeT *qkv_before_transpos_ptr =
+        NULL;  // x2 [batch_size, seq_len, 3, num_heads, head_dims]
+    XPUTypeT *qk_ptr = NULL;   // qk [batch_size, num_heads, seq_len, seq_len]
+    XPUTypeT *qkv_ptr = NULL;  // qkv [batch_size, num_heads, seq_len, head_dims]
+    XPUTypeT *linear_out_ptr =
+        NULL;  // x4, x5 [batch_size, seq_len, embed_dims]
+
+    int temp_size_1 = batch_size * seq_len * 3 * num_heads * head_dims;
+    int temp_size_2 = batch_size * num_heads * seq_len * seq_len;
+    int temp_size_3 = batch_size * num_heads * seq_len * head_dims;
+    int temp_size_4 = batch_size * seq_len * embed_dims;
+
+    std::vector<int> temp_vec = {
+        temp_size_1, temp_size_2, temp_size_3, temp_size_4};
+    std::sort(temp_vec.begin(), temp_vec.end(), std::greater<int>());
+    XPUTypeT *max_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(temp_vec[0]);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(max_gm_ptr);
+    qkv_before_transpos_ptr = max_gm_ptr;
+    qk_ptr = max_gm_ptr;
+    qkv_ptr = max_gm_ptr;
+    linear_out_ptr = max_gm_ptr;
+    int sizeof_t = sizeof(XPUTypeT);
+    for (size_t i = 0; i < temp_vec.size(); ++i) {
+      if (l3_total_size >= temp_vec[i] * sizeof_t) {
+        XPUTypeT *l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(temp_vec[i]);
+        qkv_before_transpos_ptr =
+            (temp_size_1 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
+        qk_ptr = (temp_size_2 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
+        qkv_ptr = (temp_size_3 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
+        linear_out_ptr = (temp_size_4 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
+        break;
+      }
+    }
+
+    int r = 0;
+    const XPUTypeT *x_cacl_ptr = input_x_ptr;
+    if (pre_layer_norm) {
+      r = xpu::layer_norm(xpu_ctx,
+                          input_x_ptr,
+                          ln_out_ptr,
+                          batch_size * seq_len,
+                          embed_dims,
+                          epsilon,
+                          ln_scale_ptr,
+                          ln_bias_ptr,
+                          ln_mean_ptr,
+                          ln_var_ptr);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+      x_cacl_ptr = ln_out_ptr;
+    }
+
+    // fc
+    phi::XpuFcInfo qkv_fc_info;
+    qkv_fc_info.InitFcInfo(0,
+                           batch_size * seq_len,
+                           3 * num_heads * head_dims,
+                           embed_dims,
+                           false,
+                           true,
+                           nullptr,
+                           nullptr,
+                           nullptr);
+
+    phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
+                                     x_cacl_ptr,
+                                     qkv_weight_ptr,
+                                     qkv_before_transpos_ptr,
+                                     qkv_fc_info,
+                                     1.0f);
+
+    // bias
+    r = xpu::broadcast_add(xpu_ctx,
+                           qkv_before_transpos_ptr,
+                           qkv_bias_ptr,
+                           qkv_before_transpos_ptr,
+                           {batch_size * seq_len, 3 * num_heads * head_dims},
+                           {3 * num_heads * head_dims});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+
+    // transpose
+    r = xpu::transpose(xpu_ctx,
+                       qkv_before_transpos_ptr,
+                       qkv_transpose_out_ptr,
+                       {batch_size, seq_len, 3, num_heads, head_dims},
+                       {2, 0, 3, 1, 4});
+
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+
+    int qkv_every_size = batch_size * seq_len * num_heads * head_dims;
+    {
+      float alpha = 1.0 / sqrt(head_dims);
+      r = xpu::scale(xpu_ctx,
+                     qkv_transpose_out_ptr,
+                     qkv_transpose_out_ptr,
+                     qkv_every_size,
+                     false,
+                     alpha,
+                     0.0f);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
+    }
+
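+    // Shapes through the attention block below:
+    //   q/k/v    : [batch_size, num_heads, seq_len, head_dims]
+    //   qk       : [batch_size, num_heads, seq_len, seq_len]
+    //   softmax / attn dropout keep the qk shape
+    //   qk * v   : [batch_size, num_heads, seq_len, head_dims]
+    //   fmha_out : [batch_size, seq_len, num_heads, head_dims]
+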
+    // begin fmha
+    // 1. qk 2. qk + mask 3. softmax 4. dropout 5. qkv 6. transpose
+    {
+      const XPUTypeT *q_ptr = qkv_transpose_out_ptr;
+      const XPUTypeT *k_ptr = q_ptr + qkv_every_size;
+      const XPUTypeT *v_ptr = k_ptr + qkv_every_size;
+      phi::XpuFcInfo qk_fc_info;
+      qk_fc_info.InitFcInfo(batch_size * num_heads,
+                            seq_len,
+                            seq_len,
+                            head_dims,
+                            false,
+                            true,
+                            nullptr,
+                            nullptr,
+                            nullptr);
+      phi::MatMulXPUFunction<XPUTypeT>(
+          xpu_ctx, q_ptr, k_ptr, qk_ptr, qk_fc_info, 1.0f);
+
+      if (src_mask_ptr) {
+        r = xpu::broadcast_add(xpu_ctx,
+                               qk_ptr,
+                               src_mask_ptr,
+                               qk_ptr,
+                               {batch_size, num_heads, seq_len, seq_len},
+                               {batch_size, 1, 1, seq_len});
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+      }
+      // do softmax
+      r = xpu::softmax(xpu_ctx,
+                       qk_ptr,
+                       softmax_out_ptr,
+                       {batch_size, num_heads, seq_len, seq_len},
+                       3);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
+
+      // do dropout
+      Dropout(xpu_ctx,
+              softmax_out_ptr,
+              attn_dropout_mask_out_ptr,
+              attn_dropout_out_ptr,
+              attn_dropout_param,
+              batch_size * num_heads * seq_len * seq_len);
+
+      phi::XpuFcInfo qktv_fc_info;
+      qktv_fc_info.InitFcInfo(batch_size * num_heads,
+                              seq_len,
+                              head_dims,
+                              seq_len,
+                              false,
+                              false,
+                              nullptr,
+                              nullptr,
+                              nullptr);
+      phi::MatMulXPUFunction<XPUTypeT>(
+          xpu_ctx, attn_dropout_out_ptr, v_ptr, qkv_ptr, qktv_fc_info, 1.0f);
+      r = xpu::transpose(xpu_ctx,
+                         qkv_ptr,
+                         fmha_out_ptr,
+                         {batch_size, num_heads, seq_len, head_dims},
+                         {0, 2, 1, 3});
+
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+    }
+
+    // linear_out
+    phi::XpuFcInfo linear_fc_info;
+    linear_fc_info.InitFcInfo(0,
+                              batch_size * seq_len,
+                              embed_dims,
+                              embed_dims,
+                              false,
+                              false,
+                              nullptr,
+                              nullptr,
+                              nullptr);
+    phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
+                                     fmha_out_ptr,
+                                     out_linear_weight_ptr,
+                                     linear_out_ptr,
+                                     linear_fc_info,
+                                     1.0f);
+
+    // out_linear_bias_ptr
+    r = xpu::broadcast_add(xpu_ctx,
+                           linear_out_ptr,
+                           out_linear_bias_ptr,
+                           linear_out_ptr,
+                           {batch_size * seq_len, embed_dims},
+                           {embed_dims});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+
+    Dropout(xpu_ctx,
+            linear_out_ptr,
+            dropout_mask_out_ptr,
+            linear_out_ptr,
+            dropout_param,
+            batch_size * seq_len * embed_dims);
+
+    XPUTypeT *real_out_ptr = out_ptr;
+    if (pre_layer_norm == false) {
+      real_out_ptr = bias_dropout_residual_out_ptr;
+    }
+
+    r = xpu::add(xpu_ctx,
+                 linear_out_ptr,
+                 input_x_ptr,
+                 real_out_ptr,
+                 batch_size * seq_len * embed_dims);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
+
+    if (pre_layer_norm == false) {
+      r = xpu::layer_norm(xpu_ctx,
+                          real_out_ptr,
+                          out_ptr,
+                          batch_size * seq_len,
+                          embed_dims,
+                          epsilon,
+                          ln_scale_ptr,
+                          ln_bias_ptr,
+                          ln_mean_ptr,
+                          ln_var_ptr);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+    }
+  }
+};
+
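+// The grad kernel replays the forward pipeline in reverse:
+// [post layer_norm grad] -> dropout grad -> out-linear grad -> fmha grad
+// (qktv, attn dropout, softmax, qk) -> qkv fc grad -> [pre layer_norm grad]
+// -> residual gradient accumulation.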
+template <typename T>
+class FusedAttentionGradXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    using XPUTypeT = typename XPUTypeTrait<T>::Type;
+    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
+
+    // dropout info
+    float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
+    bool is_test_1 = ctx.Attr<bool>("is_test");
+    auto &dropout_implementation_1 =
+        ctx.Attr<std::string>("attn_dropout_implementation");
+    bool is_upscale_in_train_1 =
+        (dropout_implementation_1 == "upscale_in_train");
+    auto *seed_1 =
+        ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
+    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
+    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
+
+    XPUDropoutParam attn_dropout_param;
+    attn_dropout_param.initXPUDropoutParam(attn_dropout_prob,
+                                           is_upscale_in_train_1,
+                                           is_test_1,
+                                           is_fix_seed_1,
+                                           seed_1,
+                                           seed_val_1);
+
+    XPUDropoutParam dropout_param(ctx, 0);
+    // get inputs.
+    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const XPUTypeT *d_y_ptr =
+        reinterpret_cast<const XPUTypeT *>(d_y->data<T>());
+    // forward tensors needed by the backward pass
+    auto *input_x = ctx.Input<Tensor>("X");
+    const XPUTypeT *input_x_ptr =
+        reinterpret_cast<const XPUTypeT *>(input_x->data<T>());
+    auto *qkv_transpose_out = ctx.Input<Tensor>("TransposeOut2");
+    const XPUTypeT *qkv_transpose_out_ptr =
+        reinterpret_cast<const XPUTypeT *>(qkv_transpose_out->data<T>());
+    auto *qkv_weight = ctx.Input<Tensor>("QKVW");
+    const XPUTypeT *qkv_weight_ptr =
+        reinterpret_cast<const XPUTypeT *>(qkv_weight->data<T>());
+
+    auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
+    const XPUTypeT *softmax_out_ptr =
+        reinterpret_cast<const XPUTypeT *>(softmax_out->data<T>());
+    auto *attn_dropout_out = ctx.Input<Tensor>("AttnDropoutOut");
+    const XPUTypeT *attn_dropout_out_ptr =
+        reinterpret_cast<const XPUTypeT *>(attn_dropout_out->data<T>());
+
+    auto *attn_dropout_mask = ctx.Input<Tensor>("AttnDropoutMaskOut");
+    const XPUTypeT *attn_dropout_mask_ptr =
+        reinterpret_cast<const XPUTypeT *>(attn_dropout_mask->data<T>());
+    auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
+    const XPUTypeT *fmha_out_ptr =
+        reinterpret_cast<const XPUTypeT *>(fmha_out->data<T>());
+
+    auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
+    const XPUTypeT *out_linear_weight_ptr =
+        reinterpret_cast<const XPUTypeT *>(out_linear_weight->data<T>());
+
+    auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
+    const XPUTypeT *dropout_mask_out_ptr =
+        reinterpret_cast<const XPUTypeT *>(dropout_mask_out->data<T>());
+    // gradients that need to be computed
+    auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
+    XPUTypeT *d_qkv_weight_ptr = reinterpret_cast<XPUTypeT *>(
+        d_qkv_weight->mutable_data<T>(ctx.GetPlace()));
+
+    auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias"));
+    XPUTypeT *d_qkv_bias_ptr = reinterpret_cast<XPUTypeT *>(
+        d_qkv_bias->mutable_data<T>(ctx.GetPlace()));
+    auto *d_out_linear_weight =
+        ctx.Output<Tensor>(framework::GradVarName("OutLinearW"));
+
+    XPUTypeT *d_out_linear_weight_ptr = reinterpret_cast<XPUTypeT *>(
+        d_out_linear_weight->mutable_data<T>(ctx.GetPlace()));
+
+    auto *d_out_linear_bias =
+        ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
+    XPUTypeT *d_out_linear_bias_ptr = reinterpret_cast<XPUTypeT *>(
+        d_out_linear_bias->mutable_data<T>(ctx.GetPlace()));
+    // possibly needed
+    auto *d_src_mask_out =
+        ctx.Output<Tensor>(framework::GradVarName("SrcMaskOut"));
+    XPUTypeT *d_src_mask_out_ptr =
+        (d_src_mask_out == nullptr)
+            ? (nullptr)
+            : (reinterpret_cast<XPUTypeT *>(
+                  d_src_mask_out->mutable_data<T>(ctx.GetPlace())));
+    // output dx
+    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    XPUTypeT *d_x_ptr =
+        reinterpret_cast<XPUTypeT *>(d_x->mutable_data<T>(ctx.GetPlace()));
+
+    const Tensor *ln_out = nullptr;
+    const Tensor *bias_dropout_residual_out = nullptr;
+    const Tensor *ln_scale = nullptr;
+    const Tensor *ln_mean = nullptr;
+    const Tensor *ln_var = nullptr;
+    Tensor *d_ln_scale = nullptr;
+    Tensor *d_ln_bias = nullptr;
+
+    const XPUTypeT *ln_out_ptr = NULL;
+    const float *ln_scale_ptr = NULL;
+    const float *ln_mean_ptr = NULL;
+    const float *ln_var_ptr = NULL;
+    const XPUTypeT *bias_dropout_residual_out_ptr = NULL;
+    float *d_ln_scale_ptr = nullptr;
+    float *d_ln_bias_ptr = nullptr;
+
+    float epsilon = 0.0f;
+
+    if (pre_layer_norm) {
+      ln_out = ctx.Input<Tensor>("LnOut");
+      ln_out_ptr = reinterpret_cast<const XPUTypeT *>(ln_out->data<T>());
+      ln_scale = ctx.Input<Tensor>("LnScale");
+      ln_mean = ctx.Input<Tensor>("LnMean");
+      ln_var = ctx.Input<Tensor>("LnVariance");
+      epsilon = ctx.Attr<float>("epsilon");
+      d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
+      d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
+
+    } else {
+      ln_scale = ctx.Input<Tensor>("Ln2Scale");
+      ln_mean = ctx.Input<Tensor>("Ln2Mean");
+      ln_var = ctx.Input<Tensor>("Ln2Variance");
+      epsilon = ctx.Attr<float>("ln_epsilon");
+      d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
+      d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
+      bias_dropout_residual_out = ctx.Input<Tensor>("BiasDropoutResidualOut");
+      bias_dropout_residual_out_ptr = reinterpret_cast<const XPUTypeT *>(
+          bias_dropout_residual_out->data<T>());
+    }
+
+    ln_scale_ptr = ln_scale->data<float>();
+    ln_mean_ptr = ln_mean->data<float>();
+    ln_var_ptr = ln_var->data<float>();
+    d_ln_scale_ptr = d_ln_scale->mutable_data<float>(ctx.GetPlace());
+    d_ln_bias_ptr = d_ln_bias->mutable_data<float>(ctx.GetPlace());
+
+    const auto input_x_dims = input_x->dims();
+    const auto qkv_w_dims = qkv_weight->dims();
+
+    int batch_size = input_x_dims[0];
+    int seq_len = input_x_dims[1];
+    int embed_dims = input_x_dims[2];
+    int num_heads = qkv_w_dims[1];
+    int head_dims = qkv_w_dims[2];
+
+    auto &dev_ctx = ctx.template device_context<phi::XPUContext>();
+    xpu::Context *xpu_ctx = dev_ctx.x_context();
+    xpu::ctx_guard RAII_GUARD(xpu_ctx);
+
+    int r = 0;
+    // int l3_total_size = xpu_ctx->_l3_mgr.get_size();
+    XPUTypeT *d_ln_grad_ptr = NULL;       // dx5 [batch_size, seq_len, hidden]
+    XPUTypeT *d_dropout_grad_ptr = NULL;  // dx5 [batch_size, seq_len, hidden]
+
+    XPUTypeT *d_fmha_out_ptr =
+        NULL;  // d_fmha_out [batch_size, seq_len, num_heads, head_dims]
+    XPUTypeT *d_fmha_out_transpos_tmp_ptr =
+        NULL;  // d_fmha_out transposed [batch_size, seq_len, num_heads,
+               // head_dims]
+
+    XPUTypeT *d_qk_ptr =
+        NULL;  // d_qk_ptr [batch_size, num_heads, seq_len, seq_len]
+
+    XPUTypeT *d_combination_qkv_ptr =
+        NULL;  // d_combination_qkv_ptr [3, batch_size, num_heads, seq_len,
+               // head_dims]
+    XPUTypeT *d_transpos_qkv_ptr =
+        NULL;  // dx2 [batch_size, seq_len, 3, num_heads, head_dims]
+
+    XPUTypeT *d_last_layernorm_grad_ptr =
+        NULL;  // d_layer_out [batch_size, seq_len, embed_dims]
+
+    const XPUTypeT *dy_input_ptr = d_y_ptr;
+
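+    // Scratch buffers for the backward pass: most are grabbed with
+    // alloc_l3_or_gm, which places them in L3 when it has room and in GM
+    // otherwise.
+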
+    d_ln_grad_ptr =
+        RAII_GUARD.alloc<XPUTypeT>(batch_size * seq_len * embed_dims);
+    d_dropout_grad_ptr =
+        RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len * embed_dims);
+    d_fmha_out_ptr = RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(
+        batch_size * seq_len * num_heads * head_dims);
+    d_combination_qkv_ptr =
+        RAII_GUARD.alloc<XPUTypeT>(batch_size * seq_len * embed_dims * 3);
+    d_transpos_qkv_ptr = RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(
+        batch_size * seq_len * embed_dims * 3);
+    d_fmha_out_transpos_tmp_ptr =
+        RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len * embed_dims);
+    d_qk_ptr = RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(
+        batch_size * seq_len * seq_len * num_heads);
+    d_last_layernorm_grad_ptr =
+        RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len * embed_dims);
+
+    if (pre_layer_norm == false) {
+      r = xpu::layer_norm_grad(xpu_ctx,
+                               bias_dropout_residual_out_ptr,
+                               d_y_ptr,
+                               d_ln_grad_ptr,
+                               batch_size * seq_len,
+                               embed_dims,
+                               epsilon,
+                               ln_scale_ptr,
+                               ln_mean_ptr,
+                               ln_var_ptr,
+                               d_ln_scale_ptr,
+                               d_ln_bias_ptr);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
+      dy_input_ptr = d_ln_grad_ptr;
+    }
+    // dropout_grad
+    DropoutGrad(xpu_ctx,
+                dy_input_ptr,
+                dropout_mask_out_ptr,
+                d_dropout_grad_ptr,
+                dropout_param,
+                batch_size * num_heads * seq_len * head_dims);
+
+    // linear_out
+    phi::XpuFcInfo linear_fc_info;
+    linear_fc_info.InitFcInfo(0,
+                              batch_size * seq_len,
+                              embed_dims,
+                              embed_dims,
+                              false,
+                              false,
+                              nullptr,
+                              nullptr,
+                              nullptr);
+    const XPUTypeT *a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
+    const XPUTypeT *b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
+    const XPUTypeT *a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
+    const XPUTypeT *b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
+
+    XPUTypeT *c_1 = d_fmha_out_ptr;
+    XPUTypeT *c_2 = d_out_linear_weight_ptr;
+    phi::XpuFcInfo info_dfmha;
+    phi::XpuFcInfo info_dlinear_w;
+
+    std::tuple<phi::XpuFcInfo,
+               phi::XpuFcInfo,
+               const XPUTypeT *,
+               const XPUTypeT *,
+               const XPUTypeT *,
+               const XPUTypeT *>
+        fc_info = phi::MatmulGradFcInfo(xpu_ctx,
+                                        &RAII_GUARD,
+                                        linear_fc_info,
+                                        false,
+                                        false,
+                                        fmha_out_ptr,
+                                        out_linear_weight_ptr,
+                                        d_dropout_grad_ptr);
+
+    std::tie(info_dfmha, info_dlinear_w, a_1, b_1, a_2, b_2) = fc_info;
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, a_2, b_2, c_2, info_dlinear_w, 1.0f, true);
+
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, a_1, b_1, c_1, info_dfmha, 1.0f, true);
+
+    // dlinear_bias
+    r = xpu::reduce_sum(xpu_ctx,
+                        d_dropout_grad_ptr,
+                        d_out_linear_bias_ptr,
+                        {batch_size * seq_len, embed_dims},
+                        {0});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
+    {
+      int qkv_size = batch_size * seq_len * num_heads * head_dims;
+      const XPUTypeT *q_out_ptr = qkv_transpose_out_ptr;
+      const XPUTypeT *k_out_ptr = q_out_ptr + qkv_size;
+      const XPUTypeT *v_out_ptr = k_out_ptr + qkv_size;
+      XPUTypeT *d_q_out_ptr = d_combination_qkv_ptr;
+      XPUTypeT *d_k_out_ptr = d_q_out_ptr + qkv_size;
+      XPUTypeT *d_v_out_ptr = d_k_out_ptr + qkv_size;
+      r = xpu::transpose(xpu_ctx,
+                         d_fmha_out_ptr,
+                         d_fmha_out_transpos_tmp_ptr,
+                         {batch_size, seq_len, num_heads, head_dims},
+                         {0, 2, 1, 3});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+
+      phi::XpuFcInfo qktv_fc_info;
+      qktv_fc_info.InitFcInfo(batch_size * num_heads,
+                              seq_len,
+                              head_dims,
+                              seq_len,
+                              false,
+                              false,
+                              nullptr,
+                              nullptr,
+                              nullptr);
+
+      const XPUTypeT *a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
+      const XPUTypeT *b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
+      const XPUTypeT *a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
+      const XPUTypeT *b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
+      XPUTypeT *c_1 = d_qk_ptr;
+      XPUTypeT *c_2 = d_v_out_ptr;
+      phi::XpuFcInfo info_d_qk;
+      phi::XpuFcInfo info_d_v;
+
+      std::tuple<phi::XpuFcInfo,
+                 phi::XpuFcInfo,
+                 const XPUTypeT *,
+                 const XPUTypeT *,
+                 const XPUTypeT *,
+                 const XPUTypeT *>
+          fc_info = phi::MatmulGradFcInfo(xpu_ctx,
+                                          &RAII_GUARD,
+                                          qktv_fc_info,
+                                          false,
+                                          false,
+                                          attn_dropout_out_ptr,
+                                          v_out_ptr,
+                                          d_fmha_out_transpos_tmp_ptr);
+
+      std::tie(info_d_qk, info_d_v, a_1, b_1, a_2, b_2) = fc_info;
+      phi::MatMulXPUFunction<XPUTypeT>(
+          xpu_ctx, a_1, b_1, c_1, info_d_qk, 1.0f, true);
+      phi::MatMulXPUFunction<XPUTypeT>(
+          xpu_ctx, a_2, b_2, c_2, info_d_v, 1.0f, true);
+
+      DropoutGrad(xpu_ctx,
+                  d_qk_ptr,
+                  attn_dropout_mask_ptr,
+                  d_qk_ptr,
+                  attn_dropout_param,
+                  batch_size * seq_len * seq_len * num_heads);
+
+      r = xpu::softmax_grad(xpu_ctx,
+                            softmax_out_ptr,
+                            d_qk_ptr,
+                            d_qk_ptr,
+                            {batch_size, num_heads, seq_len, seq_len},
+                            3);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad");
+
+      if (d_src_mask_out_ptr) {
+        r = xpu::copy(xpu_ctx,
+                      d_qk_ptr,
+                      d_src_mask_out_ptr,
+                      batch_size * seq_len * seq_len * num_heads);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+      }
+      phi::XpuFcInfo qk_fc_info;
+      qk_fc_info.InitFcInfo(batch_size * num_heads,
+                            seq_len,
+                            seq_len,
+                            head_dims,
+                            false,
+                            true,
+                            nullptr,
+                            nullptr,
+                            nullptr);
+
+      a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
+      b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
+      a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
+      b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
+      c_1 = d_q_out_ptr;
+      c_2 = d_k_out_ptr;
+      phi::XpuFcInfo info_d_q;
+      phi::XpuFcInfo info_d_k;
+
+      fc_info = phi::MatmulGradFcInfo(xpu_ctx,
+                                      &RAII_GUARD,
+                                      qk_fc_info,
+                                      false,
+                                      true,
+                                      q_out_ptr,
+                                      k_out_ptr,
+                                      d_qk_ptr);
+
+      std::tie(info_d_q, info_d_k, a_1, b_1, a_2, b_2) = fc_info;
+
+      phi::MatMulXPUFunction<XPUTypeT>(
+          xpu_ctx, a_1, b_1, c_1, info_d_q, 1.0f / sqrt(head_dims), true);
+
+      phi::MatMulXPUFunction<XPUTypeT>(
+          xpu_ctx, a_2, b_2, c_2, info_d_k, 1.0f, true);
+    }
+
+    r = xpu::transpose(xpu_ctx,
+                       d_combination_qkv_ptr,
+                       d_transpos_qkv_ptr,
+                       {3, batch_size, num_heads, seq_len, head_dims},
+                       {1, 3, 0, 2, 4});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+
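+    // dX of the qkv fc goes straight into d_x_ptr in the post-LN case; in
+    // the pre-LN case it is staged in d_last_layernorm_grad_ptr, because
+    // layer_norm_grad still has to map it back through the leading
+    // layer_norm before the residual gradient is accumulated.
+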
+    // dx and d_qkv_w
+    phi::XpuFcInfo qkv_fc_info;
+    qkv_fc_info.InitFcInfo(0,
+                           batch_size * seq_len,
+                           3 * num_heads * head_dims,
+                           embed_dims,
+                           false,
+                           true,
+                           nullptr,
+                           nullptr,
+                           nullptr);
+
+    a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
+    b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
+    a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
+    b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
+    c_1 = (pre_layer_norm == true) ? d_last_layernorm_grad_ptr : d_x_ptr;
+    c_2 = d_qkv_weight_ptr;
+    phi::XpuFcInfo info_d_x;
+    phi::XpuFcInfo info_d_qkv_w;
+
+    const XPUTypeT *use_calc_input_x_ptr =
+        (pre_layer_norm == true) ? ln_out_ptr : input_x_ptr;
+
+    fc_info = phi::MatmulGradFcInfo(xpu_ctx,
+                                    &RAII_GUARD,
+                                    qkv_fc_info,
+                                    false,
+                                    true,
+                                    use_calc_input_x_ptr,
+                                    qkv_weight_ptr,
+                                    d_transpos_qkv_ptr);
+
+    std::tie(info_d_x, info_d_qkv_w, a_1, b_1, a_2, b_2) = fc_info;
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, a_1, b_1, c_1, info_d_x, 1.0f, true);
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, a_2, b_2, c_2, info_d_qkv_w, 1.0f, true);
+
+    // d_qkv_bias
+    r = xpu::reduce_sum(xpu_ctx,
+                        d_transpos_qkv_ptr,
+                        d_qkv_bias_ptr,
+                        {batch_size * seq_len, 3 * embed_dims},
+                        {0});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
+
+    if (pre_layer_norm) {
+      r = xpu::layer_norm_grad(xpu_ctx,
+                               input_x_ptr,
+                               c_1,
+                               d_x_ptr,
+                               batch_size * seq_len,
+                               embed_dims,
+                               epsilon,
+                               ln_scale_ptr,
+                               ln_mean_ptr,
+                               ln_var_ptr,
+                               d_ln_scale_ptr,
+                               d_ln_bias_ptr);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
+    }
+
+    // add residual dy
+    r = xpu::add(xpu_ctx,
+                 dy_input_ptr,
+                 d_x_ptr,
+                 d_x_ptr,
+                 batch_size * seq_len * embed_dims);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(
+    fused_attention,
+    ops::FusedAttentionOpKernel<float>,
+    ops::FusedAttentionOpKernel<plat::float16>);
+
+REGISTER_OP_XPU_KERNEL(
+    fused_attention_grad,
+    ops::FusedAttentionGradXPUKernel<float>,
+    ops::FusedAttentionGradXPUKernel<plat::float16>);
+
+#endif
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op_xpu.cc b/paddle/fluid/operators/fused/fused_feedforward_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b94d37a921fb629f12dca94316d909f21be04eb5
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_feedforward_op_xpu.cc
@@ -0,0 +1,828 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/operators/matmul_v2_op.h"
+#include "paddle/fluid/operators/xpu_api_wrapper.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+#include "paddle/fluid/operators/fused/xpu_fused_common_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = phi::DenseTensor;
+
+template <typename T>
+class FusedFeedForwardXPUKernel : public framework::OpKernel<T> {
+  using XPUTypeT = typename XPUTypeTrait<T>::Type;
+
+ public:
+  void FFN(const phi::XPUContext& dev_ctx,
+           const Tensor* x,
+           const Tensor* linear1_weight,
+           const Tensor* linear1_bias,
+           const Tensor* linear2_weight,
+           const Tensor* linear2_bias,
+           const Tensor* ln_scale,
+           const Tensor* ln_bias,
+           Tensor* out,
+           Tensor* dropout1_mask,
+           Tensor* dropout2_mask,
+           Tensor* ln_mean,
+           Tensor* ln_variance,
+           Tensor* linear1_out,
+           Tensor* ln1_out,
+           Tensor* dropout1_out,
+           Tensor* dropout2_out,
+           const int bsz_seq,
+           const int d_model,
+           const int dim_feedforward,
+           const std::string& act_method,
+           const bool pre_layer_norm,
+           const float epsilon1,
+           const float epsilon2,
+           const XPUDropoutParam& dropout_param1,
+           const XPUDropoutParam& dropout_param2,
+           int ring_id) const {
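+    // FFN forward, as implemented below:
+    //   [pre layer_norm] -> linear1 -> bias add -> act (gelu|relu) ->
+    //   dropout1 -> linear2 -> bias add -> dropout2 -> residual add ->
+    //   [post layer_norm when pre_layer_norm == false]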
+    xpu::Context* xpu_ctx = dev_ctx.x_context();
+    xpu::ctx_guard RAII_GUARD(xpu_ctx);
+
+    int r = xpu::SUCCESS;
+
+    const XPUTypeT* x_ptr = reinterpret_cast<const XPUTypeT*>(x->data<T>());
+    const XPUTypeT* residual_ptr = x_ptr;
+    const XPUTypeT* linear1_weight_ptr =
+        reinterpret_cast<const XPUTypeT*>(linear1_weight->data<T>());
+    const XPUTypeT* linear1_bias_ptr =
+        reinterpret_cast<const XPUTypeT*>(linear1_bias->data<T>());
+    const XPUTypeT* linear2_weight_ptr =
+        reinterpret_cast<const XPUTypeT*>(linear2_weight->data<T>());
+    const XPUTypeT* linear2_bias_ptr =
+        reinterpret_cast<const XPUTypeT*>(linear2_bias->data<T>());
+
+    const float* ln_scale_ptr = ln_scale->data<float>();
+
+    const float* ln_bias_ptr = ln_bias->data<float>();
+
+    // out
+    XPUTypeT* out_ptr = reinterpret_cast<XPUTypeT*>(out->data<T>());
+    XPUTypeT* linear1_out_ptr =
+        reinterpret_cast<XPUTypeT*>(linear1_out->data<T>());
+    XPUTypeT* dropout1_mask_ptr =
+        reinterpret_cast<XPUTypeT*>(dropout1_mask->data<T>());
+    XPUTypeT* dropout2_mask_ptr =
+        reinterpret_cast<XPUTypeT*>(dropout2_mask->data<T>());
+    float* ln_mean_ptr = ln_mean->data<float>();
+    float* ln_variance_ptr = ln_variance->data<float>();
+
+    XPUTypeT* dropout1_out_ptr =
+        reinterpret_cast<XPUTypeT*>(dropout1_out->data<T>());
+    XPUTypeT* dropout2_out_ptr =
+        reinterpret_cast<XPUTypeT*>(dropout2_out->data<T>());
+
+    size_t l3_total_size = xpu_ctx->_l3_mgr.get_size();
+    XPUTypeT* linear2_before_tmp_ptr = NULL;  // dim_feedforward * bsz_seq
+    XPUTypeT* linear2_after_tmp_ptr = NULL;   // d_model * bsz_seq
+    if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T)) {
+      XPUTypeT* l3_ptr =
+          RAII_GUARD.alloc_l3<XPUTypeT>(dim_feedforward * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(l3_ptr);
+      linear2_before_tmp_ptr = linear2_after_tmp_ptr = l3_ptr;
+    } else if ((l3_total_size < dim_feedforward * bsz_seq * sizeof(T)) &&
+               (l3_total_size >= d_model * bsz_seq * sizeof(T))) {
+      XPUTypeT* l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(d_model * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(l3_ptr);
+      linear2_after_tmp_ptr = l3_ptr;
+      linear2_before_tmp_ptr =
+          RAII_GUARD.alloc<XPUTypeT>(dim_feedforward * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(linear2_before_tmp_ptr);
+
+    } else {
+      XPUTypeT* gm_ptr = RAII_GUARD.alloc<XPUTypeT>(dim_feedforward * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(gm_ptr);
+      linear2_before_tmp_ptr = linear2_after_tmp_ptr = gm_ptr;
+    }
+
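+    // linear2_before_tmp (bsz_seq x dim_feedforward) and linear2_after_tmp
+    // (bsz_seq x d_model) may alias one allocation; L3 is preferred when it
+    // can hold the larger tensor, then the smaller one, else both live in GM.
+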
+    // layernorm
+    if (pre_layer_norm) {
+      XPUTypeT* ln1_out_ptr = reinterpret_cast<XPUTypeT*>(ln1_out->data<T>());
+      r = xpu::layer_norm(xpu_ctx,
+                          x_ptr,
+                          ln1_out_ptr,
+                          bsz_seq,
+                          d_model,
+                          epsilon1,
+                          ln_scale_ptr,
+                          ln_bias_ptr,
+                          ln_mean_ptr,
+                          ln_variance_ptr);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+      x_ptr = ln1_out_ptr;
+    }
+
+    // fc
+    phi::XpuFcInfo linear1_fc_info;
+    linear1_fc_info.InitFcInfo(0,
+                               bsz_seq,
+                               dim_feedforward,
+                               d_model,
+                               false,
+                               false,
+                               nullptr,
+                               nullptr,
+                               nullptr);
+    phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
+                                     x_ptr,
+                                     linear1_weight_ptr,
+                                     linear2_before_tmp_ptr,
+                                     linear1_fc_info,
+                                     1.0f);
+
+    // bias
+    r = xpu::broadcast_add(xpu_ctx,
+                           linear2_before_tmp_ptr,
+                           linear1_bias_ptr,
+                           linear1_out_ptr,
+                           {bsz_seq, dim_feedforward},
+                           {dim_feedforward});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+
+    // act
+    if (act_method == "gelu") {
+      r = xpu::gelu(xpu_ctx,
+                    linear1_out_ptr,
+                    linear2_before_tmp_ptr,
+                    linear1_out->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
+    } else if (act_method == "relu") {
+      r = xpu::relu(xpu_ctx,
+                    linear1_out_ptr,
+                    linear2_before_tmp_ptr,
+                    linear1_out->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu");
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Currently only supports gelu or relu activation functions!"));
+    }
+
+    // dropout1
+    Dropout(xpu_ctx,
+            linear2_before_tmp_ptr,
+            dropout1_mask_ptr,
+            dropout1_out_ptr,
+            dropout_param1,
+            dropout1_out->numel());
+
+    // fc
+    phi::XpuFcInfo linear2_fc_info;
+    linear2_fc_info.InitFcInfo(0,
+                               bsz_seq,
+                               d_model,
+                               dim_feedforward,
+                               false,
+                               false,
+                               nullptr,
+                               nullptr,
+                               nullptr);
+    phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
+                                     dropout1_out_ptr,
+                                     linear2_weight_ptr,
+                                     dropout2_out_ptr,
+                                     linear2_fc_info,
+                                     1.0f);
+
+    // bias
+    r = xpu::broadcast_add(xpu_ctx,
+                           dropout2_out_ptr,
+                           linear2_bias_ptr,
+                           dropout2_out_ptr,
+                           {bsz_seq, d_model},
+                           {d_model});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+
+    // dropout2
+    Dropout(xpu_ctx,
+            dropout2_out_ptr,
+            dropout2_mask_ptr,
+            dropout2_out_ptr,
+            dropout_param2,
+            dropout2_out->numel());
+
+    // residual_ptr + dropout_out
+    XPUTypeT* residual_add_out_ptr = out_ptr;
+    if (pre_layer_norm == false) {
+      residual_add_out_ptr = dropout2_out_ptr;
+    }
+    r = xpu::broadcast_add(xpu_ctx,
+                           residual_ptr,
+                           dropout2_out_ptr,
+                           residual_add_out_ptr,
+                           {bsz_seq, d_model},
+                           {bsz_seq, d_model});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+
+    if (pre_layer_norm == false) {
+      r = xpu::layer_norm(xpu_ctx,
+                          residual_add_out_ptr,
+                          out_ptr,
+                          bsz_seq,
+                          d_model,
+                          epsilon2,
+                          ln_scale_ptr,
+                          ln_bias_ptr,
+                          ln_mean_ptr,
+                          ln_variance_ptr);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+    }
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto place = context.GetPlace();
+
+    auto* x = context.Input<Tensor>("X");
+
+    auto* linear1_weight = context.Input<Tensor>("Linear1Weight");
+    auto* linear1_bias = context.Input<Tensor>("Linear1Bias");
+    auto* linear2_weight = context.Input<Tensor>("Linear2Weight");
+    auto* linear2_bias = context.Input<Tensor>("Linear2Bias");
+    const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
+
+    const Tensor* ln_scale = nullptr;
+    const Tensor* ln_bias = nullptr;
+    Tensor* ln_mean = nullptr;
+    Tensor* ln_variance = nullptr;
+    Tensor* ln1_out = nullptr;
+
+    if (pre_layer_norm) {
+      ln_scale = context.Input<Tensor>("Ln1Scale");
+      ln_bias = context.Input<Tensor>("Ln1Bias");
+      ln_mean = context.Output<Tensor>("Ln1Mean");
+      ln_variance = context.Output<Tensor>("Ln1Variance");
+      ln1_out = context.Output<Tensor>("Ln1Out");
+      ln1_out->mutable_data<T>(place);
+    } else {
+      ln_scale = context.Input<Tensor>("Ln2Scale");
+      ln_bias = context.Input<Tensor>("Ln2Bias");
+      ln_mean = context.Output<Tensor>("Ln2Mean");
+      ln_variance = context.Output<Tensor>("Ln2Variance");
+    }
+
+    auto* out = context.Output<Tensor>("Out");
+    auto* dropout1_mask = context.Output<Tensor>("Dropout1Mask");
+    auto* dropout2_mask = context.Output<Tensor>("Dropout2Mask");
+    auto* linear1_out = context.Output<Tensor>("Linear1Out");
+
+    auto* dropout1_out = context.Output<Tensor>("Dropout1Out");
+    auto* dropout2_out = context.Output<Tensor>("Dropout2Out");
+
+    const std::string act_method = context.Attr<std::string>("act_method");
+
+    const int ring_id = context.Attr<int>("ring_id");
+    const float epsilon1 = context.Attr<float>("ln1_epsilon");
+    const float epsilon2 = context.Attr<float>("ln2_epsilon");
+    XPUDropoutParam dropout_param1;
+    dropout_param1.initXPUDropoutParam(context, 1);
+    XPUDropoutParam dropout_param2;
+    dropout_param2.initXPUDropoutParam(context, 2);
+
+    ln_mean->mutable_data<float>(place);
+    ln_variance->mutable_data<float>(place);
+    out->mutable_data<T>(place);
+    dropout1_mask->mutable_data<T>(place);
+    dropout2_mask->mutable_data<T>(place);
+    dropout1_out->mutable_data<T>(place);
+    dropout2_out->mutable_data<T>(place);
+    linear1_out->mutable_data<T>(place);
+
+    auto x_dim = x->dims();
+    auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
+        RowMatrixFromVector(x_dim), 0, false);
+
+    auto dim = linear1_weight->dims();
+    int d_model = dim[0];
+    int dim_feedforward = dim[dim.size() - 1];
+    int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
+
+    auto& dev_ctx = context.template device_context<phi::XPUContext>();
+    FFN(dev_ctx,
+        x,
+        linear1_weight,
+        linear1_bias,
+        linear2_weight,
+        linear2_bias,
+        ln_scale,
+        ln_bias,
+        out,
+        dropout1_mask,
+        dropout2_mask,
+        ln_mean,
+        ln_variance,
+        linear1_out,
+        ln1_out,
+        dropout1_out,
+        dropout2_out,
+        bsz_seq,
+        d_model,
+        dim_feedforward,
+        act_method,
+        pre_layer_norm,
+        epsilon1,
+        epsilon2,
+        dropout_param1,
+        dropout_param2,
+        ring_id);
+  }
+};
+
+template <typename T>
+class FusedFeedForwardGradXPUKernel : public framework::OpKernel<T> {
+  using XPUTypeT = typename XPUTypeTrait<T>::Type;
+
+ public:
+  void FFNGrad(const phi::XPUContext& dev_ctx,
+               const Tensor* d_out,
+               const Tensor* x,
+               const Tensor* dropout1_mask,
+               const Tensor* dropout2_mask,
+               const Tensor* linear1_out,
+               const Tensor* ln1_out,
+               const Tensor* dropout1_out,
+               const Tensor* dropout2_out,
+               const Tensor* linear1_weight,
+               const Tensor* linear2_weight,
+               const Tensor* ln_scale,
+               const Tensor* ln_mean,
+               const Tensor* ln_variance,
+               Tensor* d_x,
+               Tensor* d_linear1_weight,
+               Tensor* d_linear1_bias,
+               Tensor* d_linear2_weight,
+               Tensor* d_linear2_bias,
+               Tensor* d_ln_scale,
+               Tensor* d_ln_bias,
+               const int bsz_seq,
+               const int d_model,
+               const int dim_feedforward,
+               const XPUDropoutParam& dropout_param1,
+               const XPUDropoutParam& dropout_param2,
+               const std::string& act_method,
+               const bool pre_layer_norm,
+               const float epsilon,
+               const int ring_id) const {
+    xpu::Context* xpu_ctx = dev_ctx.x_context();
+    xpu::ctx_guard RAII_GUARD(xpu_ctx);
+    int r = xpu::SUCCESS;
+
+    // inputs ptr
+    const XPUTypeT* d_out_ptr =
+        reinterpret_cast<const XPUTypeT*>(d_out->data<T>());
+    const XPUTypeT* x_ptr = reinterpret_cast<const XPUTypeT*>(x->data<T>());
+    const XPUTypeT* dropout1_mask_ptr =
+        reinterpret_cast<const XPUTypeT*>(dropout1_mask->data<T>());
+    const XPUTypeT* dropout2_mask_ptr =
+        reinterpret_cast<const XPUTypeT*>(dropout2_mask->data<T>());
+    const XPUTypeT* linear1_out_ptr =
+        reinterpret_cast<const XPUTypeT*>(linear1_out->data<T>());
+    const XPUTypeT* dropout1_out_ptr =
+        reinterpret_cast<const XPUTypeT*>(dropout1_out->data<T>());
+    const XPUTypeT* linear1_weight_ptr =
+        reinterpret_cast<const XPUTypeT*>(linear1_weight->data<T>());
+    const XPUTypeT* linear2_weight_ptr =
+        reinterpret_cast<const XPUTypeT*>(linear2_weight->data<T>());
+    const float* ln_scale_ptr = ln_scale->data<float>();
+
+    const float* ln_mean_ptr = ln_mean->data<float>();
+    const float* ln_variance_ptr = ln_variance->data<float>();
+    // outputs ptr
+    XPUTypeT* d_x_ptr = reinterpret_cast<XPUTypeT*>(d_x->data<T>());
+    XPUTypeT* d_linear1_weight_ptr =
+        reinterpret_cast<XPUTypeT*>(d_linear1_weight->data<T>());
+    XPUTypeT* d_linear1_bias_ptr =
+        reinterpret_cast<XPUTypeT*>(d_linear1_bias->data<T>());
+    XPUTypeT* d_linear2_weight_ptr =
+        reinterpret_cast<XPUTypeT*>(d_linear2_weight->data<T>());
+    XPUTypeT* d_linear2_bias_ptr =
+        reinterpret_cast<XPUTypeT*>(d_linear2_bias->data<T>());
+    float* d_ln_scale_ptr = d_ln_scale->data<float>();
+    float* d_ln_bias_ptr = d_ln_bias->data<float>();
+
+    size_t l3_total_size = xpu_ctx->_l3_mgr.get_size();
+
+    XPUTypeT* big_tmp_l3_ptr = NULL;    // dim_feedforward * bsz_seq
+    XPUTypeT* small_tmp_l3_ptr = NULL;  // d_model * bsz_seq
+    XPUTypeT* big_tmp_gm_ptr = NULL;    // dim_feedforward * bsz_seq
+    XPUTypeT* small_tmp_gm_ptr = NULL;  // d_model * bsz_seq
+
+    XPUTypeT* d_layernorm_out_ptr = NULL;  // dx9
+    XPUTypeT* d_dropout2_out_ptr = NULL;   // dx7
+
+    XPUTypeT* d_linear2_out_ptr = NULL;   // dx5
+    XPUTypeT* d_dropout1_out_ptr = NULL;  // dx4
+    XPUTypeT* d_act_out_ptr = NULL;       // dx3
+
+    XPUTypeT* d_linear1_out_ptr = NULL;  // dx1
+
+    const XPUTypeT* d_residual_ptr = d_out_ptr;
+
+    if (l3_total_size >= (dim_feedforward * bsz_seq * sizeof(T) +
+                          d_model * bsz_seq * sizeof(T))) {
+      big_tmp_l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(dim_feedforward * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_l3_ptr);
+      small_tmp_l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(d_model * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_l3_ptr);
+      d_layernorm_out_ptr = small_tmp_l3_ptr;
+      d_dropout2_out_ptr = small_tmp_l3_ptr;
+      d_linear2_out_ptr = big_tmp_l3_ptr;
+      d_dropout1_out_ptr = big_tmp_l3_ptr;
+      d_act_out_ptr = big_tmp_l3_ptr;
+      d_linear1_out_ptr = small_tmp_l3_ptr;
+    } else if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T)) {
+      big_tmp_l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(dim_feedforward * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_l3_ptr);
+      small_tmp_l3_ptr = big_tmp_l3_ptr;
+      big_tmp_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(dim_feedforward * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_gm_ptr);
+      small_tmp_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(d_model * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_gm_ptr);
+
+      d_layernorm_out_ptr = small_tmp_l3_ptr;
+      d_dropout2_out_ptr = small_tmp_gm_ptr;
+      d_linear2_out_ptr = big_tmp_l3_ptr;
+      d_dropout1_out_ptr = big_tmp_l3_ptr;
+      d_act_out_ptr = big_tmp_gm_ptr;
+      d_linear1_out_ptr = small_tmp_l3_ptr;
+
+    } else if (l3_total_size >= d_model * bsz_seq * sizeof(T)) {
+      big_tmp_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(dim_feedforward * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_gm_ptr);
+      small_tmp_l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(d_model * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_l3_ptr);
+
+      d_layernorm_out_ptr = small_tmp_l3_ptr;
+      d_dropout2_out_ptr = small_tmp_l3_ptr;
+      d_linear2_out_ptr = big_tmp_gm_ptr;
+      d_dropout1_out_ptr = big_tmp_gm_ptr;
+      d_act_out_ptr = big_tmp_gm_ptr;
+      d_linear1_out_ptr = small_tmp_l3_ptr;
+    } else {
+      big_tmp_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(dim_feedforward * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_gm_ptr);
+      small_tmp_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(d_model * bsz_seq);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_gm_ptr);
+      d_layernorm_out_ptr = small_tmp_gm_ptr;
+      d_dropout2_out_ptr = small_tmp_gm_ptr;
+      d_linear2_out_ptr = big_tmp_gm_ptr;
+      d_dropout1_out_ptr = big_tmp_gm_ptr;
+      d_act_out_ptr = big_tmp_gm_ptr;
+      d_linear1_out_ptr = small_tmp_gm_ptr;
+    }
+
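+    // Workspace cases above, by available L3 size: (1) both the
+    // dim_feedforward-sized and the d_model-sized temporaries fit in L3;
+    // (2) only the big one fits (a small GM copy backs d_dropout2_out);
+    // (3) only the small one fits; (4) everything stays in GM.
+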
+    if (pre_layer_norm == false) {
+      const XPUTypeT* dropout2_out_ptr =
+          reinterpret_cast<const XPUTypeT*>(dropout2_out->data<T>());
+      r = xpu::layer_norm_grad(xpu_ctx,
+                               dropout2_out_ptr,
+                               d_out_ptr,
+                               d_layernorm_out_ptr,
+                               bsz_seq,
+                               d_model,
+                               epsilon,
+                               ln_scale_ptr,
+                               ln_mean_ptr,
+                               ln_variance_ptr,
+                               d_ln_scale_ptr,
+                               d_ln_bias_ptr);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
+      d_residual_ptr = d_layernorm_out_ptr;
+    }
+    DropoutGrad(xpu_ctx,
+                d_residual_ptr,
+                dropout2_mask_ptr,
+                d_dropout2_out_ptr,
+                dropout_param2,
+                bsz_seq * d_model);
+    // linear_grad2
+    r = xpu::reduce_sum(xpu_ctx,
+                        d_dropout2_out_ptr,
+                        d_linear2_bias_ptr,
+                        {bsz_seq, d_model},
+                        {0});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
+
+    phi::XpuFcInfo linear2_fc_info;
+    linear2_fc_info.InitFcInfo(0,
+                               bsz_seq,
+                               d_model,
+                               dim_feedforward,
+                               false,
+                               false,
+                               nullptr,
+                               nullptr,
+                               nullptr);
+
+    const XPUTypeT* a_1 = reinterpret_cast<const XPUTypeT*>(NULL);
+    const XPUTypeT* b_1 = reinterpret_cast<const XPUTypeT*>(NULL);
+    const XPUTypeT* a_2 = reinterpret_cast<const XPUTypeT*>(NULL);
+    const XPUTypeT* b_2 = reinterpret_cast<const XPUTypeT*>(NULL);
+    XPUTypeT* c_1 = d_linear2_out_ptr;
+    XPUTypeT* c_2 = d_linear2_weight_ptr;
+    phi::XpuFcInfo info_d_dropout1;
+    phi::XpuFcInfo info_dw2;
+
+    std::tuple<phi::XpuFcInfo,
+               phi::XpuFcInfo,
+               const XPUTypeT*,
+               const XPUTypeT*,
+               const XPUTypeT*,
+               const XPUTypeT*>
+        fc_info = phi::MatmulGradFcInfo(xpu_ctx,
+                                        &RAII_GUARD,
+                                        linear2_fc_info,
+                                        false,
+                                        false,
+                                        dropout1_out_ptr,
+                                        linear2_weight_ptr,
+                                        d_dropout2_out_ptr);
+
+    std::tie(info_d_dropout1, info_dw2, a_1, b_1, a_2, b_2) = fc_info;
+
+    // if l3_total_size >= dim_feedforward * bsz_seq * sizeof(T), transpose
+    // first
+    if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T) &&
+        info_dw2.trans_x) {
+      r = xpu::transpose(xpu_ctx,
+                         dropout1_out_ptr,
+                         big_tmp_l3_ptr,
+                         {bsz_seq, dim_feedforward},
+                         {1, 0});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+      a_2 = big_tmp_l3_ptr;
+      info_dw2.trans_x = !info_dw2.trans_x;
+      info_dw2.stride_x = info_dw2.k;
+    }
+
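+    // With enough L3, the activation is transposed up front so the dW GEMM
+    // can run without an implicit transpose: trans_x is flipped and the
+    // stride updated to match the already-transposed layout.
+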
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, a_1, b_1, c_1, info_d_dropout1, 1.0f, true);
+
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, a_2, b_2, c_2, info_dw2, 1.0f, true);
+
+    // dropout_grad1
+    DropoutGrad(xpu_ctx,
+                d_linear2_out_ptr,
+                dropout1_mask_ptr,
+                d_dropout1_out_ptr,
+                dropout_param1,
+                bsz_seq * dim_feedforward);
+
+    // act_grad
+    if (act_method == "gelu") {
+      r = xpu::gelu_grad(xpu_ctx,
+                         linear1_out_ptr,
+                         linear1_out_ptr,
+                         d_dropout1_out_ptr,
+                         d_act_out_ptr,
+                         bsz_seq * dim_feedforward);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad");
+    } else if (act_method == "relu") {
+      r = xpu::relu_grad(xpu_ctx,
+                         linear1_out_ptr,
+                         linear1_out_ptr,
+                         d_dropout1_out_ptr,
+                         d_act_out_ptr,
+                         bsz_seq * dim_feedforward);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu_grad");
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Currently only supports gelu or relu activation functions!"));
+    }
+
+    // linear1_grad
+    r = xpu::reduce_sum(xpu_ctx,
+                        d_act_out_ptr,
+                        d_linear1_bias_ptr,
+                        {bsz_seq, dim_feedforward},
+                        {0});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
+
+    phi::XpuFcInfo linear1_fc_info;
+    linear1_fc_info.InitFcInfo(0,
+                               bsz_seq,
+                               dim_feedforward,
+                               d_model,
+                               false,
+                               false,
+                               nullptr,
+                               nullptr,
+                               nullptr);
+
+    a_1 = reinterpret_cast<const XPUTypeT*>(NULL);
+    b_1 = reinterpret_cast<const XPUTypeT*>(NULL);
+    a_2 = reinterpret_cast<const XPUTypeT*>(NULL);
+    b_2 = reinterpret_cast<const XPUTypeT*>(NULL);
+
+    c_1 = (pre_layer_norm == true ? d_linear1_out_ptr : d_x_ptr);
+    c_2 = d_linear1_weight_ptr;
+    phi::XpuFcInfo info_dx;
+    phi::XpuFcInfo info_dw1;
+
+    const XPUTypeT* linear1_x_ptr =
+        (pre_layer_norm == true
+             ? reinterpret_cast<const XPUTypeT*>(ln1_out->data<T>())
+             : x_ptr);
+
+    if (l3_total_size >= d_model * bsz_seq * sizeof(T) && info_dw1.trans_x) {
+      r = xpu::transpose(
+          xpu_ctx, linear1_x_ptr, small_tmp_l3_ptr, {bsz_seq, d_model}, {1, 0});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+      a_2 = small_tmp_l3_ptr;
+      info_dw1.trans_x = !info_dw1.trans_x;
+      info_dw1.stride_x = info_dw1.k;
+    }
+
+    fc_info = phi::MatmulGradFcInfo(xpu_ctx,
+                                    &RAII_GUARD,
+                                    linear1_fc_info,
+                                    false,
+                                    false,
+                                    linear1_x_ptr,
+                                    linear1_weight_ptr,
+                                    d_act_out_ptr);
+
+    std::tie(info_dx, info_dw1, a_1, b_1, a_2, b_2) = fc_info;
+
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f, true);
+
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, a_2, b_2, c_2, info_dw1, 1.0f, true);
+
+    if (pre_layer_norm) {
+      r = xpu::layer_norm_grad(xpu_ctx,
+                               x_ptr,
+                               c_1,
+                               c_1,
+                               bsz_seq,
+                               d_model,
+                               epsilon,
+                               ln_scale_ptr,
+                               ln_mean_ptr,
+                               ln_variance_ptr,
+                               d_ln_scale_ptr,
+                               d_ln_bias_ptr);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
+    }
+
+    r = xpu::add(xpu_ctx, c_1, d_residual_ptr, d_x_ptr, d_model * bsz_seq);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto place = context.GetPlace();
+    const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
+    // inputs
+    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x = context.Input<Tensor>("X");
+
+    auto* dropout1_mask = context.Input<Tensor>("Dropout1Mask");
+    auto* dropout2_mask = context.Input<Tensor>("Dropout2Mask");
+    auto* linear1_out = context.Input<Tensor>("Linear1Out");
+    auto* ln1_out =
+        pre_layer_norm ? context.Input<Tensor>("Ln1Out") : nullptr;
+
+    auto* dropout1_out = context.Input<Tensor>("Dropout1Out");
+    auto* dropout2_out = context.Input<Tensor>("Dropout2Out");
+    auto* linear1_weight = context.Input<Tensor>("Linear1Weight");
+    auto* linear2_weight = context.Input<Tensor>("Linear2Weight");
+
+    const Tensor* ln_mean = nullptr;
+    const Tensor* ln_variance = nullptr;
+    const Tensor* ln_scale = nullptr;
+
+    if (pre_layer_norm) {
+      ln_mean = context.Input<Tensor>("Ln1Mean");
+      ln_variance = context.Input<Tensor>("Ln1Variance");
+      ln_scale = context.Input<Tensor>("Ln1Scale");
+    } else {
+      ln_mean = context.Input<Tensor>("Ln2Mean");
+      ln_variance = context.Input<Tensor>("Ln2Variance");
+      ln_scale = context.Input<Tensor>("Ln2Scale");
+    }
+
+    // output
+    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
+
+    Tensor* d_ln_scale = nullptr;
+    Tensor* d_ln_bias = nullptr;
+
+    if (pre_layer_norm) {
+      d_ln_scale = context.Output<Tensor>(framework::GradVarName("Ln1Scale"));
+      d_ln_bias = context.Output<Tensor>(framework::GradVarName("Ln1Bias"));
+    } else {
+      d_ln_scale = context.Output<Tensor>(framework::GradVarName("Ln2Scale"));
+      d_ln_bias = context.Output<Tensor>(framework::GradVarName("Ln2Bias"));
+    }
+
+    auto* d_linear1_weight =
+        context.Output<Tensor>(framework::GradVarName("Linear1Weight"));
+    auto* d_linear1_bias =
+        context.Output<Tensor>(framework::GradVarName("Linear1Bias"));
+    auto* d_linear2_weight =
+        context.Output<Tensor>(framework::GradVarName("Linear2Weight"));
+    auto* d_linear2_bias =
+        context.Output<Tensor>(framework::GradVarName("Linear2Bias"));
+
+    float epsilon = 0.0f;
+    if (pre_layer_norm) {
+      epsilon = context.Attr<float>("ln1_epsilon");
+    } else {
+      epsilon = context.Attr<float>("ln2_epsilon");
+    }
+
+    const std::string act_method = context.Attr<std::string>("act_method");
+
+    XPUDropoutParam dropout_param1(context, 1);
+    XPUDropoutParam dropout_param2(context, 2);
+
+    const int ring_id = context.Attr<int>("ring_id");
+
+    d_x->mutable_data<T>(place);
+    d_ln_scale->mutable_data<float>(place);
+    d_ln_bias->mutable_data<float>(place);
+    d_linear1_bias->mutable_data<T>(place);
+    d_linear2_bias->mutable_data<T>(place);
+    d_linear1_weight->mutable_data<T>(place);
+    d_linear2_weight->mutable_data<T>(place);
+
+    auto x_dim = x->dims();
+    auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
+        RowMatrixFromVector(x_dim), 0, false);
+
+    auto linear1_weight_dim = linear1_weight->dims();
+    int d_model = linear1_weight_dim[0];
+    int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
+    int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
+    auto& dev_ctx = context.template device_context<phi::XPUContext>();
+
+    FFNGrad(dev_ctx,
+            d_out,
+            x,
+            dropout1_mask,
+            dropout2_mask,
+            linear1_out,
+            ln1_out,
+            dropout1_out,
+            dropout2_out,
+            linear1_weight,
+            linear2_weight,
+            ln_scale,
+            ln_mean,
+            ln_variance,
+            d_x,
+            d_linear1_weight,
+            d_linear1_bias,
+            d_linear2_weight,
+            d_linear2_bias,
+            d_ln_scale,
+            d_ln_bias,
+            bsz_seq,
+            d_model,
+            dim_feedforward,
+            dropout_param1,
+            dropout_param2,
+            act_method,
+            pre_layer_norm,
+            epsilon,
+            ring_id);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(
+    fused_feedforward,
+    ops::FusedFeedForwardXPUKernel<float>,
+    ops::FusedFeedForwardXPUKernel<paddle::platform::float16>);
+
+REGISTER_OP_XPU_KERNEL(
+    fused_feedforward_grad,
+    ops::FusedFeedForwardGradXPUKernel<float>,
+    ops::FusedFeedForwardGradXPUKernel<paddle::platform::float16>);
+
+#endif
diff --git a/paddle/fluid/operators/fused/xpu_fused_common_function.h b/paddle/fluid/operators/fused/xpu_fused_common_function.h
new file mode 100644
index 0000000000000000000000000000000000000000..1a1ec8c47f9bad5727986903f84d11125631b212
--- /dev/null
+++ b/paddle/fluid/operators/fused/xpu_fused_common_function.h
@@ -0,0 +1,224 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/platform/device/device_wrapper.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = phi::DenseTensor;
+
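+// Attribute naming handled by XPUDropoutParam: for dropout_index == 0 it
+// reads "dropout_rate" / "DropoutSeed"; for index N > 0 it reads
+// "dropoutN_rate" / "DropoutNSeed", matching the fused attention and
+// feedforward op definitions.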
+struct XPUDropoutParam {
+  float dropout_prob;
+  bool is_upscale_in_train;
+  bool is_test;
+  bool fix_seed;
+  const Tensor *tensor_seed;
+  int seed_val;
+
+  XPUDropoutParam() {
+    fix_seed = false;
+    is_test = false;
+    is_upscale_in_train = false;
+    dropout_prob = 0.5;
+    tensor_seed = nullptr;
+    seed_val = 0;
+  }
+
+  XPUDropoutParam(const framework::ExecutionContext &context,
+                  const int dropout_index) {
+    std::string pre_fix = "dropout";
+    std::string str_index = std::to_string(dropout_index);
+    if (dropout_index > 0) {
+      pre_fix = pre_fix + str_index + "_";
+    } else {
+      pre_fix = pre_fix + "_";
+    }
+    dropout_prob = context.Attr<float>(pre_fix + "rate");
+    auto &dropout_implementation =
+        context.Attr<std::string>(pre_fix + "implementation");
+    is_upscale_in_train = (dropout_implementation == "upscale_in_train");
+    is_test = context.Attr<bool>("is_test");
+    fix_seed = context.Attr<bool>(pre_fix + "fix_seed");
+
+    std::string str_seed = "Dropout";
+    if (dropout_index > 0) {
+      str_seed = str_seed + str_index + "Seed";
+    } else {
+      str_seed = str_seed + "Seed";
+    }
+
+    tensor_seed =
+        context.HasInput(str_seed) ? context.Input<Tensor>(str_seed) : nullptr;
+    if (tensor_seed) {
+      seed_val = *(tensor_seed->data<int>());
+    } else {
+      seed_val = fix_seed ? context.Attr<int>(pre_fix + "seed") : 0;
+    }
+  }
+
+  void initXPUDropoutParam(float dropout_prob_,
+                           bool is_upscale_in_train_,
+                           bool is_test_,
+                           bool fix_seed_,
+                           const Tensor *tensor_seed,
+                           int seed_val_) {
+    dropout_prob = dropout_prob_;
+    is_upscale_in_train = is_upscale_in_train_;
+    is_test = is_test_;
+    fix_seed = fix_seed_;
+    if (tensor_seed) {
+      seed_val = *(tensor_seed->data<int>());
+    } else {
+      seed_val = fix_seed ? seed_val_ : 0;
+    }
+  }
+
+  void initXPUDropoutParam(const framework::ExecutionContext &context,
+                           int dropout_index) {
+    std::string pre_fix = "dropout";
+    std::string str_index = std::to_string(dropout_index);
+    if (dropout_index > 0) {
+      pre_fix = pre_fix + str_index + "_";
+    } else {
+      pre_fix = pre_fix + "_";
+    }
+    dropout_prob = context.Attr<float>(pre_fix + "rate");
+    auto &dropout_implementation =
+        context.Attr<std::string>(pre_fix + "implementation");
+    is_upscale_in_train = (dropout_implementation == "upscale_in_train");
+    is_test = context.Attr<bool>("is_test");
+    fix_seed = context.Attr<bool>(pre_fix + "fix_seed");
+    std::string str_seed = "Dropout";
+    if (dropout_index > 0) {
+      str_seed = str_seed + str_index + "Seed";
+    } else {
+      str_seed = str_seed + "Seed";
+    }
+    tensor_seed =
+        context.HasInput(str_seed) ? context.Input<Tensor>(str_seed) : nullptr;
+
+    if (tensor_seed) {
+      seed_val = *(tensor_seed->data<int>());
+    } else {
+      seed_val = fix_seed ? context.Attr<int>(pre_fix + "seed") : 0;
+    }
+  }
+};
+
+/******************
+ * check whether a pointer is in L3
+ *******************/
+
+static bool is_in_l3(const void *addr) {
+  int64_t addr_int = (int64_t)addr;
+  int addr_int_high = addr_int >> 32;
+  return (addr_int_high == 0);
+}
+
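+// Dropout semantics implemented below: prob == 0 is a plain copy; in
+// training, prob == 1 zeroes both output and mask, otherwise xpu::dropout
+// draws the mask; at inference the input is only rescaled by (1 - prob)
+// unless upscale_in_train already folded the scale into training.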
+/*************************
+ * dropout
+ *************************/
+
+template <typename T>
+void Dropout(xpu::Context *xpu_ctx,
+             const T *x,
+             T *mask,
+             T *y,
+             const XPUDropoutParam &param,
+             int len) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  int r = XPU_SUCCESS;
+  if (param.dropout_prob == 0.0f) {
+    r = xpu::copy(xpu_ctx,
+                  reinterpret_cast<const XPUType *>(x),
+                  reinterpret_cast<XPUType *>(y),
+                  len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+    return;
+  }
+  if (!param.is_test) {
+    if (param.dropout_prob == 1.0f) {
+      r = xpu::constant(
+          xpu_ctx, reinterpret_cast<XPUType *>(y), len, XPUType(0));
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+      r = xpu::constant(
+          xpu_ctx, reinterpret_cast<XPUType *>(mask), len, XPUType(0));
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    } else {
+      r = xpu::dropout(xpu_ctx,
+                       reinterpret_cast<const XPUType *>(x),
+                       reinterpret_cast<XPUType *>(y),
+                       reinterpret_cast<XPUType *>(mask),
+                       param.seed_val,
+                       len,
+                       param.is_upscale_in_train,
+                       param.dropout_prob);
+
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
+    }
+  } else {
+    float scale = (param.is_upscale_in_train)
+                      ? (1.0)
+                      : (static_cast<float>(1.0f - param.dropout_prob));
+    r = xpu::scale(xpu_ctx,
+                   reinterpret_cast<const XPUType *>(x),
+                   reinterpret_cast<XPUType *>(y),
+                   len,
+                   false,
+                   scale,
+                   0.0f);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
+  }
+}
+
+template <typename T>
+void DropoutGrad(xpu::Context *xpu_ctx,
+                 const T *dy,
+                 const T *mask,
+                 T *dx,
+                 const XPUDropoutParam &param,
+                 int len) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  if (param.dropout_prob == 0.0f) {
+    int r = xpu::copy(xpu_ctx,
+                      reinterpret_cast<const XPUType *>(dy),
+                      reinterpret_cast<XPUType *>(dx),
+                      len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+    return;
+  }
+  if (!param.is_upscale_in_train) {
+    int r = xpu::mul(xpu_ctx,
+                     reinterpret_cast<const XPUType *>(dy),
+                     reinterpret_cast<const XPUType *>(mask),
+                     reinterpret_cast<XPUType *>(dx),
+                     len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
+  } else {
+    int r = xpu::dropout_grad(xpu_ctx,
+                              reinterpret_cast<const XPUType *>(mask),
+                              reinterpret_cast<const XPUType *>(dy),
+                              reinterpret_cast<XPUType *>(dx),
+                              param.dropout_prob,
+                              len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad");
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
+#endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 8773ae273a69e4521e8bee6a67536dbfd07cb5f8..cbcbde8f9ddcd2f22f1b106ecad4adf93c279436 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -704,6 +704,18 @@ XPUOpMap& get_kl2_ops() {
       {"fused_gemm_epilogue_grad",
        XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                      pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"fused_attention",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"fused_attention_grad",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"fused_feedforward",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"fused_feedforward_grad",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
   };
 
   return s_xpu2_kernels;
diff --git a/paddle/phi/kernels/xpu/xpu_api_wrapper.h b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
index 8433c6b421eed026a3db3c6e7f2afa04d788737b..277a4e953d6e19fb5135eb44a51c8d59c05e958c 100644
--- a/paddle/phi/kernels/xpu/xpu_api_wrapper.h
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 8773ae273a69e4521e8bee6a67536dbfd07cb5f8..cbcbde8f9ddcd2f22f1b106ecad4adf93c279436 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -704,6 +704,18 @@ XPUOpMap& get_kl2_ops() {
       {"fused_gemm_epilogue_grad",
        XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                      pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"fused_attention",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"fused_attention_grad",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"fused_feedforward",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"fused_feedforward_grad",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
   };
 
   return s_xpu2_kernels;
diff --git a/paddle/phi/kernels/xpu/xpu_api_wrapper.h b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
index 8433c6b421eed026a3db3c6e7f2afa04d788737b..277a4e953d6e19fb5135eb44a51c8d59c05e958c 100644
--- a/paddle/phi/kernels/xpu/xpu_api_wrapper.h
+++ b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
@@ -382,7 +382,8 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
                               const T* y,
                               T* out,
                               const XpuFcInfo& fcinfo,
-                              float alpha) {
+                              float alpha,
+                              bool is_grad = false) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   int fccal_type = FCCalcType<XPUType>();
@@ -398,6 +399,12 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
   };
   auto fc_api = fc_api_list[fccal_type];
+  if (std::getenv("XPU_PADDLE_FC_GRAD_LOCAL") != nullptr) {
+    if (is_grad) {
+      fc_api = fc_api_list[2];
+    }
+  }
+
   auto fc_batch_api = fc_batch_api_list[fccal_type];
 
   int m = fcinfo.m;
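One behavioral detail of this wrapper change is easy to miss: the new is_grad flag does nothing on its own. The override to fc_api_list[2] only fires when the XPU_PADDLE_FC_GRAD_LOCAL environment variable is set, pinning backward FC calls to one fixed entry of the calc-type table (presumably the plain FP32 path, though that depends on how fc_api_list is populated). A hedged sketch of opting in from a test driver:

import os

# Any value works, even an empty string: the C++ side only checks
# std::getenv("XPU_PADDLE_FC_GRAD_LOCAL") != nullptr.
os.environ["XPU_PADDLE_FC_GRAD_LOCAL"] = "1"

# ...then run the fused_attention / fused_feedforward tests as usual;
# backward matmuls will now take the fc_api_list[2] code path.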
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_fused_attention_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_fused_attention_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..6462bec102ee50314bdd6e078e5e5bd307a453aa
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_fused_attention_op_xpu.py
@@ -0,0 +1,331 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import sys
+
+sys.path.append("..")
+
+import paddle
+import paddle.nn.functional as F
+import paddle.incubate.nn.functional as incubate_f
+from paddle.nn.layer.norm import LayerNorm
+from paddle.nn.layer.common import Linear, Dropout
+from paddle.nn.layer.transformer import _convert_attention_mask
+from paddle import tensor
+from paddle.fluid import layers
+import unittest
+from op_test_xpu import XPUOpTest
+from paddle.fluid.framework import default_main_program
+
+from xpu.get_test_cover_info import (
+    create_test_class,
+    get_xpu_op_support_types,
+    XPUOpTestWrapper,
+)
+
+default_main_program().random_seed = 42
+
+
+class XPUTestFusedAttentionOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'fused_attention'
+        self.use_dynamic_create_class = False
+
+    class TestFusedAttentionOp(XPUOpTest):
+        def setUp(self):
+            self.config()
+            self.generate_input_data()
+            self.rtol = 1e-5
+            self.atol = 1e-3
+            if self.x_type == np.float16 or str(self.x_type) == "float16":
+                self.atol = 1e-1
+
+            paddle.set_default_dtype(self.x_type)
+            self.__class__.op_type = "fused_attention"
+            # use autograd to check grad in this unittest.
+            self.__class__.no_need_check_grad = True
+            self.q_proj = Linear(
+                self.embed_dim,
+                self.embed_dim,
+                self.weight_attr,
+                bias_attr=self.bias_attr,
+            )
+            self.k_proj = Linear(
+                self.kdim,
+                self.embed_dim,
+                self.weight_attr,
+                bias_attr=self.bias_attr,
+            )
+            self.v_proj = Linear(
+                self.vdim,
+                self.embed_dim,
+                self.weight_attr,
+                bias_attr=self.bias_attr,
+            )
+            self.out_proj = Linear(
+                self.embed_dim,
+                self.embed_dim,
+                self.weight_attr,
+                bias_attr=self.bias_attr,
+            )
+            paddle.set_default_dtype(np.float32)
+            self.norm1 = LayerNorm(self.embed_dim)
+            self.norm2 = LayerNorm(self.embed_dim)
+            paddle.set_default_dtype(self.x_type)
+            self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")
+
+        def config(self):
+            self.x_type = self.in_type
+            self.attn_mask_type = np.float32
+            self.pre_layer_norm = True
+            self.has_attn_mask = False
+            self.training = True
+
+            self.batch_size = 8
+            self.query_length = 128
+            self.cache_length = 128
+            self.head_dim = 64
+            self.num_heads = 16
+            self.embed_dim = self.head_dim * self.num_heads
+
+            self.dropout_prob = 0.0
+            self.attn_dropout_prob = 0.0
+            self.weight_attr = None
+            self.bias_attr = None
+            self.kdim, self.vdim = self.embed_dim, self.embed_dim
+            self.key_length, self.value_length = (
+                self.query_length,
+                self.query_length,
+            )
+
+        def generate_input_data(self):
+            self.query = np.random.rand(
+                self.batch_size, self.query_length, self.embed_dim
+            ).astype(self.x_type)
+            out_seq_len = self.key_length
+            if self.has_attn_mask:
+                # [B, n_head, seq_len, out_seq_len]
+                self.attn_mask = np.ones(
+                    (
+                        self.batch_size,
+                        self.num_heads,
+                        self.query_length,
+                        out_seq_len,
+                    ),
+                    dtype=self.attn_mask_type,
+                )
+            else:
+                self.attn_mask = None
+            self.key, self.value = self.query, self.query
+
+            self.dout = np.random.random(
+                (self.batch_size, self.query_length, self.embed_dim)
+            ).astype(self.x_type)
+
+        def GetBaselineOut(self):
+            paddle.disable_static()
+            tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
+
+            if self.has_attn_mask:
+                attn_mask = paddle.to_tensor(
+                    self.attn_mask, stop_gradient=False
+                )
+            else:
+                attn_mask = None
+            residual = tensor_query
+
+            ln1_out = tensor_query
+            if self.pre_layer_norm:
+                ln1_out = self.norm1(tensor_query)
+
+            q = self.q_proj(ln1_out)
+            q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
+            q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
+            k = self.k_proj(ln1_out)
+            v = self.v_proj(ln1_out)
+            k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
+            k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
+            v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
+            v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])
+
+            # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
+            # --> [B, n_head, seq_len, out_seq_len]
+            qk_out = layers.matmul(
+                x=q_out * self.head_dim**-0.5, y=k_out, transpose_y=True
+            )
+
+            if attn_mask is not None:
+                attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
+                attn_mask_out = qk_out + attn_mask
+                softmax_out = F.softmax(attn_mask_out)
+            else:
+                softmax_out = F.softmax(qk_out)
+
+            if self.dropout_prob:
+                dropout_out = F.dropout(
+                    softmax_out,
+                    self.dropout_prob,
+                    training=self.training,
+                    mode="upscale_in_train",
+                )
+                # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
+                # --> [B, n_head, seq_len, head_dim]
+                qktv_out = tensor.matmul(dropout_out, v_out)
+            else:
+                qktv_out = tensor.matmul(softmax_out, v_out)
+
+            fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
+            out_linear_in = tensor.reshape(
+                x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]
+            )
+            out = self.out_proj(out_linear_in)
+
+            residual_out = residual + self.dropout(out)
+            if not self.pre_layer_norm:
+                final_out = self.norm1(residual_out)
+            else:
+                final_out = residual_out
+
+            paddle.autograd.backward(
+                [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
+            )
+            return final_out, tensor_query.grad
+
+        def GetFusedAttentionOut(self):
+            paddle.disable_static()
+            q_proj_weight = paddle.to_tensor(
+                self.q_proj.weight, stop_gradient=False
+            )
+            k_proj_weight = paddle.to_tensor(
+                self.k_proj.weight, stop_gradient=False
+            )
+            v_proj_weight = paddle.to_tensor(
+                self.v_proj.weight, stop_gradient=False
+            )
+            out_linear_weight = paddle.to_tensor(
+                self.out_proj.weight, stop_gradient=False
+            )
+
+            if self.bias_attr is False:
+                qkv_bias_tensor = None
+                out_linear_bias = None
+            else:
+                q_proj_bias = paddle.to_tensor(
+                    self.q_proj.bias, stop_gradient=False
+                )
+                k_proj_bias = paddle.to_tensor(
+                    self.k_proj.bias, stop_gradient=False
+                )
+                v_proj_bias = paddle.to_tensor(
+                    self.v_proj.bias, stop_gradient=False
+                )
+                qkv_bias = np.concatenate(
+                    (
+                        q_proj_bias.numpy(),
+                        k_proj_bias.numpy(),
+                        v_proj_bias.numpy(),
+                    )
+                )
+                qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
+                qkv_bias_tensor = paddle.to_tensor(
+                    qkv_bias, stop_gradient=False
+                )
+                out_linear_bias = paddle.to_tensor(
+                    self.out_proj.bias, stop_gradient=False
+                )
+
+            ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
+            ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
+            ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
+            ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)
+
+            q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
+            k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
+            v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
+            qkv_weight = np.concatenate(
+                (q_proj_weight, k_proj_weight, v_proj_weight)
+            )
+            qkv_weight = qkv_weight.reshape(
+                (3, self.num_heads, self.head_dim, self.embed_dim)
+            )
+
+            x = paddle.to_tensor(self.query, stop_gradient=False)
+            cache_kv = None
+            if self.has_attn_mask:
+                attn_mask = paddle.to_tensor(
+                    self.attn_mask, stop_gradient=False
+                )
+            else:
+                attn_mask = None
+            qkv_weight_tensor = paddle.to_tensor(
+                qkv_weight, stop_gradient=False
+            )
+            epsilon = 1e-05
+            ln2_epsilon = 1e-05
+
+            if attn_mask is not None:
+                attn_mask = _convert_attention_mask(attn_mask, x.dtype)
+            final_out = incubate_f.fused_multi_head_attention(
+                x,
+                qkv_weight_tensor,
+                out_linear_weight,
+                self.pre_layer_norm,
+                ln1_scale,
+                ln1_bias,
+                ln2_scale,
+                ln2_bias,
+                epsilon,
+                qkv_bias_tensor,
+                out_linear_bias,
+                cache_kv,
+                attn_mask,
+                self.dropout_prob,
+                self.attn_dropout_prob,
+                ln2_epsilon,
+            )
+
+            paddle.autograd.backward(
+                [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
+            )
+            return final_out, x.grad
+
+        def test_fused_attention_op(self):
+            final_out_ref, x_grad_ref = self.GetBaselineOut()
+            final_out, x_grad = self.GetFusedAttentionOut()
+            np.testing.assert_allclose(
+                final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
+            )
+            np.testing.assert_allclose(
+                x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol
+            )
+
+    class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
+        def config(self):
+            super().config()
+            self.pre_layer_norm = True
+
+    class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
+        def config(self):
+            super().config()
+            self.pre_layer_norm = True
+            self.has_attn_mask = False
+
+
+support_types = get_xpu_op_support_types('fused_attention')
+for stype in support_types:
+    create_test_class(globals(), XPUTestFusedAttentionOp, stype)
+
+if __name__ == "__main__":
+    unittest.main()
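The least obvious step in GetFusedAttentionOut above is the weight re-packing: fused_attention consumes a single QKVW tensor of shape [3, num_heads, head_dim, embed_dim], assembled from three separate Linear weights, which Paddle stores as (in_features, out_features) and which therefore get transposed first. A standalone NumPy shape check, with hypothetical sizes, distilled from the test:

import numpy as np

embed_dim, num_heads = 64, 4          # hypothetical sizes
head_dim = embed_dim // num_heads

# Paddle's Linear stores its weight as (in_features, out_features).
wq, wk, wv = (np.random.rand(embed_dim, embed_dim).astype("float32")
              for _ in range(3))

# Transpose each to (out, in), stack q/k/v, split the output dim into heads.
qkv_w = np.concatenate((wq.T, wk.T, wv.T)).reshape(
    3, num_heads, head_dim, embed_dim
)
assert qkv_w.shape == (3, num_heads, head_dim, embed_dim)

# The matching QKVBias packing is the same idea one rank down:
bq = bk = bv = np.zeros(embed_dim, dtype="float32")
qkv_b = np.concatenate((bq, bk, bv)).reshape(3, num_heads, head_dim)
assert qkv_b.shape == (3, num_heads, head_dim)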
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_fused_feedforward_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_fused_feedforward_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a6fb75eba0e3078122f2cbfd54f3d3bf9a2c26
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_fused_feedforward_op_xpu.py
@@ -0,0 +1,379 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import sys
+
+sys.path.append("..")
+
+import paddle
+from paddle.nn.layer import transformer
+import paddle.nn.functional as F
+import paddle.incubate.nn.functional as incubate_f
+from paddle.nn.layer.norm import LayerNorm
+from paddle.nn.layer.common import Linear, Dropout
+import unittest
+from op_test_xpu import XPUOpTest
+from paddle.fluid.framework import default_main_program
+
+from xpu.get_test_cover_info import (
+    create_test_class,
+    XPUOpTestWrapper,
+)
+
+
+class XPUTestFusedFFNOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'fused_feedforward'
+        self.use_dynamic_create_class = False
+
+    class TestFusedFFNOp(XPUOpTest):
+        def getDtype(self):
+            self.dtype = self.in_type
+            self.layer_norm_dtype = "float32"
+
+        def getShape(self):
+            self.batch_size = np.random.randint(1, 32)
+            self.query_length = np.random.randint(32, 128)
+            self.d_model = np.random.randint(32, 512)
+            self.dim_feedforward = np.random.randint(32, 512)
+
+        def getDiff(self):
+            self.rtol = 1e-2
+            self.atol = 1e-3
+            if self.dtype == np.float16 or self.dtype == "float16":
+                self.atol = 1e-1
+
+        def getActivation(self):
+            self.act_method = "gelu"
+
+        def getNormalizeBefore(self):
+            self.pre_layer_norm = False
+
+        def setUp(self):
+            paddle.disable_static()
+            self.__class__.op_type = "fused_feedforward"
+            # check grad in test_out_and_grad()
+            self.__class__.no_need_check_grad = True
+            self.getDtype()
+            self.getShape()
+            self.getDiff()
+            self.getActivation()
+            self.getNormalizeBefore()
+            paddle.set_default_dtype(self.dtype)
+            self.weight_attr = None
+            self.bias_attr = None
+
+            self.weight_attrs = transformer._convert_param_attr_to_list(
+                self.weight_attr, 2
+            )
+            self.bias_attrs = transformer._convert_param_attr_to_list(
+                self.bias_attr, 2
+            )
+            self.linear1 = Linear(
+                self.d_model,
+                self.dim_feedforward,
+                self.weight_attrs[1],
+                bias_attr=self.bias_attrs[1],
+            )
+            self.linear2 = Linear(
+                self.dim_feedforward,
+                self.d_model,
+                self.weight_attrs[1],
+                bias_attr=self.bias_attrs[1],
+            )
+
+            paddle.set_default_dtype(self.layer_norm_dtype)
+            self.norm1 = LayerNorm(self.d_model)
+            self.norm2 = LayerNorm(self.d_model)
+            paddle.set_default_dtype(self.dtype)
+            self.dropout1 = Dropout(0.0, mode="upscale_in_train")
+            self.dropout2 = Dropout(0.0, mode="upscale_in_train")
+            self.activation = getattr(F, self.act_method)
+
+            self.src = np.random.random(
+                (self.batch_size, self.query_length, self.d_model)
+            ).astype(self.dtype)
+            self.dout = np.random.random(
+                (self.batch_size, self.query_length, self.d_model)
+            ).astype(self.dtype)
+
+        def Base(self):
+            paddle.disable_static()
+            tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
+            residual = tensor_src
+            if self.pre_layer_norm:
+                ln1_out = self.norm1(tensor_src)
+                linear2_out = self.linear2(
+                    self.dropout1(self.activation(self.linear1(ln1_out)))
+                )
+                dropout2_out = residual + self.dropout2(linear2_out)
+                paddle.autograd.backward(
+                    [dropout2_out], [paddle.to_tensor(self.dout)], True
+                )
+                return dropout2_out, tensor_src.grad
+            else:
+                linear2_out = self.linear2(
+                    self.dropout1(self.activation(self.linear1(tensor_src)))
+                )
+                dropout2_out = residual + self.dropout2(linear2_out)
+                dropout2_out = self.norm2(dropout2_out)
+                paddle.autograd.backward(
+                    [dropout2_out], [paddle.to_tensor(self.dout)], True
+                )
+                return dropout2_out, tensor_src.grad
+
+        def FusedFFN(self):
+            paddle.disable_static()
+            linear1_weight = paddle.to_tensor(
+                self.linear1.weight, stop_gradient=False
+            )
+            linear1_bias = paddle.to_tensor(
+                self.linear1.bias, stop_gradient=False
+            )
+            linear2_weight = paddle.to_tensor(
+                self.linear2.weight, stop_gradient=False
+            )
+            linear2_bias = paddle.to_tensor(
+                self.linear2.bias, stop_gradient=False
+            )
+            ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
+            ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
+            ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
+            ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)
+            x = paddle.to_tensor(self.src, stop_gradient=False)
+            out = incubate_f.fused_feedforward(
+                x,
+                linear1_weight,
+                linear2_weight,
+                linear1_bias,
+                linear2_bias,
+                ln1_scale,
+                ln1_bias,
+                ln2_scale,
+                ln2_bias,
+                0.0,
+                0.0,
+                activation=self.act_method,
+                pre_layer_norm=self.pre_layer_norm,
+            )
+            paddle.autograd.backward([out], [paddle.to_tensor(self.dout)])
+            return out, x.grad
+
+        def test_out_and_grad(self):
+            default_main_program().random_seed = 42
+            base_out, base_grad = self.Base()
+            fused_out, fused_grad = self.FusedFFN()
+            np.testing.assert_allclose(
+                base_out.numpy(),
+                fused_out.numpy(),
+                rtol=self.rtol,
+                atol=self.atol,
+            )
+            np.testing.assert_allclose(
+                base_grad.numpy(),
+                fused_grad.numpy(),
+                rtol=self.rtol,
+                atol=self.atol,
+            )
+
+    class TestFusedFFNOpActivation(TestFusedFFNOp):
+        def getActivation(self):
+            self.act_method = "relu"
+
+    class TestFusedFFNOpNormalizeBefore(TestFusedFFNOp):
+        def getNormalizeBefore(self):
+            self.pre_layer_norm = True
+
+        def getShape(self):
+            self.batch_size = 1
+            self.query_length = 1
+            self.d_model = 8
+            self.dim_feedforward = 8
+
+
+class APITestStaticFusedFFN(unittest.TestCase):
+    def test_static(self):
+        paddle.enable_static()
+        default_main_program().random_seed = 42
+        dtype = "float32"
+        layer_norm_dtype = "float32"
+        batch_size = 1
+        d_model = 8
+        dim_feedforward = 8
+
+        x = paddle.static.data(
+            name='x', shape=[batch_size, d_model, dim_feedforward], dtype=dtype
+        )
+        linear1_weight = paddle.static.data(
+            name='linear1_weight', shape=[d_model, dim_feedforward], dtype=dtype
+        )
+        linear1_bias = paddle.static.data(
+            name='linear1_bias', shape=[dim_feedforward], dtype=dtype
+        )
+        linear2_weight = paddle.static.data(
+            name='linear2_weight', shape=[dim_feedforward, d_model], dtype=dtype
+        )
+        linear2_bias = paddle.static.data(name='linear2_bias', shape=[d_model])
+        ln1_scale = paddle.static.data(name='ln1_scale', shape=[d_model])
+        ln1_bias = paddle.static.data(name='ln1_bias', shape=[d_model])
+        ln2_scale = paddle.static.data(name='ln2_scale', shape=[d_model])
+        ln2_bias = paddle.static.data(name='ln2_bias', shape=[d_model])
+
+        fused_out = incubate_f.fused_feedforward(
+            x,
+            linear1_weight,
+            linear2_weight,
+            linear1_bias,
+            linear2_bias,
+            ln1_scale,
+            ln1_bias,
+            ln2_scale,
+            ln2_bias,
+            0.0,
+            0.0,
+            activation="relu",
+            pre_layer_norm=False,
+        )
+
+        linear1_out = F.linear(x, linear1_weight, linear1_bias)
+        act_out = F.relu(linear1_out)
+        dropout1_out = F.dropout(x=act_out, p=0.0, training=False)
+        linear2_out = F.linear(dropout1_out, linear2_weight, linear2_bias)
+        dropout2_out = x + F.dropout(x=linear2_out, p=0.0, training=False)
+        ln_out = F.layer_norm(
+            dropout2_out,
+            normalized_shape=list([d_model]),
+            weight=ln2_scale,
+            bias=ln2_bias,
+        )
+
+        exe = paddle.static.Executor(paddle.XPUPlace(0))
+
+        x_data = np.random.random(
+            (batch_size, d_model, dim_feedforward)
+        ).astype(dtype)
+        linear1_weight_data = np.random.random(
+            (d_model, dim_feedforward)
+        ).astype(dtype)
+        linear1_bias_data = np.zeros((dim_feedforward)).astype(dtype)
+        linear2_weight_data = np.random.random(
+            (dim_feedforward, d_model)
+        ).astype(dtype)
+        linear2_bias_data = np.zeros((d_model)).astype(dtype)
+
+        ln1_scale_data = np.ones((d_model)).astype(layer_norm_dtype)
+        ln1_bias_data = np.zeros((d_model)).astype(layer_norm_dtype)
+        ln2_scale_data = np.ones((d_model)).astype(layer_norm_dtype)
+        ln2_bias_data = np.zeros((d_model)).astype(layer_norm_dtype)
+
+        res_list = [fused_out, ln_out]
+        real_res = []
+
+        for res in res_list:
+            fetch = exe.run(
+                feed={
+                    'x': x_data,
+                    'linear1_weight': linear1_weight_data,
+                    'linear1_bias': linear1_bias_data,
+                    'linear2_weight': linear2_weight_data,
+                    'linear2_bias': linear2_bias_data,
+                    'ln1_scale': ln1_scale_data,
+                    'ln1_bias': ln1_bias_data,
+                    'ln2_scale': ln2_scale_data,
+                    'ln2_bias': ln2_bias_data,
+                },
+                fetch_list=[res],
+            )
+            real_res.append(fetch)
+        np.testing.assert_allclose(
+            real_res[0], real_res[1], rtol=1e-05, atol=0.001
+        )
+
+
+class TestFusedFFNOpError(unittest.TestCase):
+    def test_errors(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
+
+            def test_dtype():
+                x = paddle.static.data(
+                    name='x', shape=[1, 10, 10], dtype="int32"
+                )
+                linear1_weight = paddle.static.data(
+                    name='linear1_weight', shape=[1, 10, 10], dtype="float32"
+                )
+                linear2_weight = paddle.static.data(
+                    name='linear2_weight', shape=[1, 10, 10], dtype="float32"
+                )
+                incubate_f.fused_feedforward(x, linear1_weight, linear2_weight)
+
+            self.assertRaises(TypeError, test_dtype)
+
+            def test_dropout_rate_type():
+                x = paddle.static.data(
+                    name='x1', shape=[1, 10, 10], dtype="float32"
+                )
+                linear1_weight = paddle.static.data(
+                    name='linear1_weight1', shape=[10, 10], dtype="float32"
+                )
+                linear2_weight = paddle.static.data(
+                    name='linear2_weight1', shape=[10, 10], dtype="float32"
+                )
+                incubate_f.fused_feedforward(
+                    x, linear1_weight, linear2_weight, dropout1_rate="a"
+                )
+
+            self.assertRaises(TypeError, test_dropout_rate_type)
+
+            def test_dropout_rate_value():
+                x = paddle.static.data(
+                    name='x2', shape=[1, 10, 10], dtype="float32"
+                )
+                linear1_weight = paddle.static.data(
+                    name='linear1_weight2', shape=[10, 10], dtype="float32"
+                )
+                linear2_weight = paddle.static.data(
+                    name='linear2_weight2', shape=[10, 10], dtype="float32"
+                )
+                incubate_f.fused_feedforward(
+                    x, linear1_weight, linear2_weight, dropout2_rate=-1
+                )
+
+            self.assertRaises(ValueError, test_dropout_rate_value)
+
+            def test_dropout_mode():
+                x = paddle.static.data(
+                    name='x3', shape=[1, 10, 10], dtype="float32"
+                )
+                linear1_weight = paddle.static.data(
+                    name='linear1_weight3', shape=[10, 10], dtype="float32"
+                )
+                linear2_weight = paddle.static.data(
+                    name='linear2_weight3', shape=[10, 10], dtype="float32"
+                )
+                incubate_f.fused_feedforward(
+                    x, linear1_weight, linear2_weight, mode='test'
+                )
+
+            self.assertRaises(ValueError, test_dropout_mode)
+
+
+support_types = {"float32"}  # get_xpu_op_support_types('fused_feedforward')
+for stype in support_types:
+    create_test_class(globals(), XPUTestFusedFFNOp, stype)
+
+if __name__ == "__main__":
+    unittest.main()
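For completeness, the dynamic-graph call path these FFN tests exercise can be reduced to a few lines. A minimal sketch, assuming a Paddle build with XPU support and device 0 available; the positional parameter order matches the FusedFFN helper above, and all sizes are hypothetical:

import paddle
import paddle.incubate.nn.functional as incubate_f

paddle.set_device("xpu")

d_model, d_ffn = 8, 32                      # hypothetical sizes
x = paddle.randn([2, 4, d_model])           # [batch, seq_len, d_model]
w1, b1 = paddle.randn([d_model, d_ffn]), paddle.zeros([d_ffn])
w2, b2 = paddle.randn([d_ffn, d_model]), paddle.zeros([d_model])
ln_w, ln_b = paddle.ones([d_model]), paddle.zeros([d_model])

out = incubate_f.fused_feedforward(
    x, w1, w2, b1, b2,
    ln_w, ln_b, ln_w, ln_b,   # ln1/ln2 scale and bias (shared here for brevity)
    0.0, 0.0,                 # dropout1_rate, dropout2_rate
    activation="relu",
    pre_layer_norm=False,
)
print(out.shape)              # [2, 4, 8]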