Unverified · Commit 7b56bd25 authored by S Sonder and committed by GitHub

Move fused_attention op to phi [Migrate XPU OpKernel] [test=kunlun] (#53011)

* trans fused attention to phi

* add optional param

* trans fused_attention_grad to phi

* add fused attention grad register info

* fix include

* test=kunlun

* add fused attention to static build list

* add remove

* update remove
Parent 543efcc5
@@ -33,8 +33,6 @@ std::set<std::string> OpsWithFluidKernelNeedMoveToPhi = {
     "cudnn_lstm",
     "dequantize",
     "distributed_fused_lamb",
-    "fused_attention",
-    "fused_attention_grad",
     "fused_batch_norm_act",
     "fused_batch_norm_act_grad",
     "fusion_group",
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FusedAttentionGradKernel(
const Context &dev_ctx,
const DenseTensor &out_grad,
const DenseTensor &x,
const DenseTensor &qkv_weight,
const paddle::optional<DenseTensor> &qkv_bias,
const paddle::optional<DenseTensor> &qkv_bias_out,
const paddle::optional<DenseTensor> &src_mask,
const paddle::optional<DenseTensor> &src_mask_out,
const DenseTensor &out_linear_weight,
const paddle::optional<DenseTensor> &out_linear_bias,
const paddle::optional<DenseTensor> &ln_scale,
const paddle::optional<DenseTensor> &ln_bias,
const paddle::optional<DenseTensor> &ln_scale_2,
const paddle::optional<DenseTensor> &ln_bias_2,
const paddle::optional<DenseTensor> &ln_out,
const paddle::optional<DenseTensor> &ln_mean,
const paddle::optional<DenseTensor> &ln_var,
const paddle::optional<DenseTensor> &ln_mean_2,
const paddle::optional<DenseTensor> &ln_var_2,
const paddle::optional<DenseTensor> &bias_dropout_residual_out,
const DenseTensor &qkv_out,
const DenseTensor &transpose_out_2,
const DenseTensor &qk_out,
const DenseTensor &qktv_out,
const DenseTensor &softmax_out,
const DenseTensor &attn_dropout_mask_out,
const DenseTensor &attn_dropout_out,
const DenseTensor &fmha_out,
const DenseTensor &out_linear_out,
const DenseTensor &dropout_mask_out,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string &attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string &dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
DenseTensor *qkv_bias_grad,
DenseTensor *qkv_bias_out_grad,
DenseTensor *src_mask_out_grad,
DenseTensor *out_linear_bias_grad,
DenseTensor *ln_scale_grad,
DenseTensor *ln_bias_grad,
DenseTensor *ln_scale_2_grad,
DenseTensor *ln_bias_2_grad,
DenseTensor *x_grad,
DenseTensor *qkv_weight_grad,
DenseTensor *out_linear_weight_grad,
DenseTensor *ln_out_grad,
DenseTensor *bias_dropout_residual_out_grad,
DenseTensor *qkv_out_grad,
DenseTensor *qktv_out_grad,
DenseTensor *transpose_out_2_grad,
DenseTensor *qk_out_grad,
DenseTensor *softmax_out_grad,
DenseTensor *attn_dropout_out_grad,
DenseTensor *fmha_out_grad,
DenseTensor *out_linear_out_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
/**
* @brief Fused Attention Kernel.
* @param ctx device context
* @param x The input tensor.
* @param ln_scale (optional) Scale is a 1-dimensional tensor of size
* H. Here, H represents the last dimension of its
* input tensor.
* @param ln_bias (optional) Bias is a 1-dimensional tensor of size
* H. Here, H represents the last dimension of its
* input tensor.
* @param qkv_weight The qkv weight tensor.
* @param qkv_bias The qkv bias tensor.
* @param cache_kv (optional) The cache KV for generation inference.
* @param src_mask (optional) The attention mask tensor in fmha.
* @param out_linear_w The out_linear weight tensor.
* @param out_linear_bias (optional) The out_linear bias tensor.
* @param ln_scale_2 (optional) Scale is a 1-dimensional tensor of
* size H. Here, H represents the last dimension of its input tensor.
* @param ln_bias_2 (optional) Bias is a 1-dimensional tensor of size
* H. Here, H represents the last dimension of its
* input tensor.
 * @param num_heads The number of heads for multi_head_attention.
 * @param transpose_qkv_wb If true, the qkv_w shape is (h, 3h) and it is
 * transposed internally.
 * @param pre_layer_norm If true, the attention op uses the pre_layer_norm
 * architecture; otherwise it uses the post_layer_norm
 * architecture [default false].
* @param epsilon Constant for numerical stability [default 1e-5].
* @param attn_dropout_rate Probability of setting units to zero.
 * @param is_test (bool, default false) Set to true for inference
 * only, false for training. Some layers may run
 * faster when this is true.
 * @param attn_dropout_fix_seed A flag indicating whether to use a fixed seed
 * to generate the random mask. NOTE: DO NOT set this flag to
 * true in training. Setting it to true is only useful in
 * unittests or for debugging, so that the same output units
 * are always dropped.
* @param attn_dropout_seed Dropout random seed.
 * @param attn_dropout_implementation ["downgrade_in_infer"|"upscale_in_train"]
 * There are two ways to implement dropout (the mask below is
 * a tensor with the same shape as the input; its values are
 * 0 or 1, and the ratio of zeros is dropout_rate):
 * 1. downgrade_in_infer (default): downgrade the outcome at
 * inference time.
 * train: out = input * mask
 * inference: out = input * (1.0 - dropout_rate)
 * 2. upscale_in_train: upscale the outcome at training time
 * and do nothing at inference.
 * train: out = input * mask / (1.0 - dropout_rate)
 * inference: out = input
 * The dropout op can then be removed from the inference
 * program, making it more efficient.
 * (A scalar sketch of both modes follows this header.)
* @param dropout_rate Probability of setting units to zero.
 * @param dropout_fix_seed A flag indicating whether to use a fixed seed to
 * generate the random mask. NOTE: DO NOT set this flag to
 * true in training. Setting it to true is only useful in
 * unittests or for debugging, so that the same output units
 * are always dropped.
* @param dropout_seed Dropout random seed.
 * @param dropout_implementation ["downgrade_in_infer"|"upscale_in_train"]
 * The meaning is the same as
 * 'attn_dropout_implementation'.
* @param ln_epsilon Constant for numerical stability [default 1e-5].
* @param add_residual Whether to add residual.
 * @param ring_id Ring id for tensor model parallelism in
 * distributed training and inference.
* @param ln_mean Mean of the current mini batch.
* @param ln_var Variance of the current mini batch.
* @param ln_out The output tensor after layer_norm.
* @param qkv_out Result after qkv.
* @param qkv_bias_out Result after qkv and bias op.
* @param transpose_out_2 Result in fmha.
* @param qk_out Result in fmha.
* @param qktv_out Result in fmha.
 * @param softmax_out Result in fmha.
* @param attn_dropout_mask_out Result in fmha.
* @param attn_dropout_out Result in fmha.
* @param src_mask_out Result in fmha.
* @param fmha_out Result in fmha.
* @param out_linear_out Result after out_linear.
* @param dropout_mask_out The random sampled dropout mask.
* @param ln_mean_2 Mean of the current mini batch.
* @param ln_var_2 Variance of the current mini batch.
* @param bias_dropout_residual_out Result of residual + dropout(src + bias).
 * @param cache_kv_out The updated cache KV.
* @param y Result after attention.
*/
template <typename T, typename Context>
void FusedAttentionKernel(const Context &dev_ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &ln_scale,
const paddle::optional<DenseTensor> &ln_bias,
const DenseTensor &qkv_weight,
const paddle::optional<DenseTensor> &qkv_bias,
const paddle::optional<DenseTensor> &cache_kv,
const paddle::optional<DenseTensor> &src_mask,
const DenseTensor &out_linear_weight,
const paddle::optional<DenseTensor> &out_linear_bias,
const paddle::optional<DenseTensor> &ln_scale_2,
const paddle::optional<DenseTensor> &ln_bias_2,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string &attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string &dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
DenseTensor *ln_mean,
DenseTensor *ln_var,
DenseTensor *ln_out,
DenseTensor *qkv_out,
DenseTensor *qkv_bias_out,
DenseTensor *transpose_out_2,
DenseTensor *qk_out,
DenseTensor *qktv_out,
DenseTensor *softmax_out,
DenseTensor *attn_dropout_mask_out,
DenseTensor *attn_dropout_out,
DenseTensor *src_mask_out,
DenseTensor *fmha_out,
DenseTensor *out_linear_out,
DenseTensor *dropout_mask_out,
DenseTensor *ln_mean_2,
DenseTensor *ln_var_2,
DenseTensor *bias_dropout_residual_out,
DenseTensor *cache_kv_out,
DenseTensor *out);
} // namespace phi
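The two dropout modes referenced by attn_dropout_implementation and dropout_implementation above differ only in where the 1/(1.0 - dropout_rate) factor is applied. A minimal scalar sketch of that arithmetic, for illustration only (DropoutElement is a hypothetical helper, not part of the kernel API):

// Scalar sketch of both dropout modes. mask_keep is the sampled 0/1 mask value
// for this element, p is the dropout rate (p < 1).
// downgrade_in_infer: train out = x * mask, inference out = x * (1 - p)
// upscale_in_train:   train out = x * mask / (1 - p), inference out = x
inline float DropoutElement(float x, float p, bool mask_keep, bool is_test,
                            bool upscale_in_train) {
  if (is_test) {
    return upscale_in_train ? x : x * (1.0f - p);
  }
  float kept = mask_keep ? x : 0.0f;
  return upscale_in_train ? kept / (1.0f - p) : kept;
}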
This diff is collapsed.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fused_attention_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
#include "paddle/phi/kernels/xpu/xpu_fused_common_function.h"
namespace phi {
template <typename T, typename Context>
void FusedAttentionKernel(const Context &dev_ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &ln_scale,
const paddle::optional<DenseTensor> &ln_bias,
const DenseTensor &qkv_weight,
const paddle::optional<DenseTensor> &qkv_bias,
const paddle::optional<DenseTensor> &cache_kv,
const paddle::optional<DenseTensor> &src_mask,
const DenseTensor &out_linear_weight,
const paddle::optional<DenseTensor> &out_linear_bias,
const paddle::optional<DenseTensor> &ln_scale_2,
const paddle::optional<DenseTensor> &ln_bias_2,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string &attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string &dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
DenseTensor *ln_mean,
DenseTensor *ln_var,
DenseTensor *ln_out,
DenseTensor *qkv_out,
DenseTensor *qkv_bias_out,
DenseTensor *transpose_out_2,
DenseTensor *qk_out,
DenseTensor *qktv_out,
DenseTensor *softmax_out,
DenseTensor *attn_dropout_mask_out,
DenseTensor *attn_dropout_out,
DenseTensor *src_mask_out,
DenseTensor *fmha_out,
DenseTensor *out_linear_out,
DenseTensor *dropout_mask_out,
DenseTensor *ln_mean_2,
DenseTensor *ln_var_2,
DenseTensor *bias_dropout_residual_out,
DenseTensor *cache_kv_out,
DenseTensor *out) {
using XPUTypeT = typename XPUTypeTrait<T>::Type;
// shape [batch_size, 1, 1, seq_len]
const phi::DenseTensor *src_mask_p = src_mask.get_ptr();
const phi::DenseTensor *ln_scale_p = nullptr;
const phi::DenseTensor *ln_bias_p = nullptr;
if (pre_layer_norm) {
ln_scale_p = ln_scale.get_ptr();
ln_bias_p = ln_bias.get_ptr();
} else {
ln_scale_p = ln_scale_2.get_ptr();
ln_bias_p = ln_bias_2.get_ptr();
epsilon = ln_epsilon;
}
dev_ctx.template Alloc<T>(qk_out);
dev_ctx.template Alloc<T>(qktv_out);
dev_ctx.template Alloc<T>(out_linear_out);
dev_ctx.template Alloc<T>(qkv_bias_out);
dev_ctx.template Alloc<T>(src_mask_out);
dev_ctx.template Alloc<T>(qkv_out);
bool is_upscale_in_train_1 =
(attn_dropout_implementation == "upscale_in_train");
const phi::DenseTensor *seed_1 = nullptr;
phi::XPUDropoutParam attn_dropout_param;
attn_dropout_param.initXPUDropoutParam(attn_dropout_rate,
is_upscale_in_train_1,
is_test,
attn_dropout_fix_seed,
seed_1,
attn_dropout_seed);
phi::XPUDropoutParam dropout_param;
dropout_param.initXPUDropoutParam(dropout_rate,
is_upscale_in_train_1,
is_test,
dropout_fix_seed,
seed_1,
dropout_seed);
// Compute the dimensions first.
const auto input_x_dims = x.dims();
const auto qkv_w_dims = qkv_weight.dims();
int batch_size = input_x_dims[0];
int seq_len = input_x_dims[1];
int embed_dims = input_x_dims[2];
num_heads = qkv_w_dims[1];
int head_dims = qkv_w_dims[2];
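// The qkv weight appears to follow the [3, num_heads, head_dims, embed_dims]
// layout (the transpose_qkv_wb == false path), which is why num_heads and
// head_dims can be read back from dims 1 and 2 of qkv_w_dims.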
// Input pointers.
const XPUTypeT *input_x_ptr = reinterpret_cast<const XPUTypeT *>(x.data<T>());
const XPUTypeT *qkv_weight_ptr =
reinterpret_cast<const XPUTypeT *>(qkv_weight.data<T>());
const DenseTensor *qkv_bias_p = qkv_bias.get_ptr();
const XPUTypeT *qkv_bias_ptr =
reinterpret_cast<const XPUTypeT *>(qkv_bias_p->data<T>());
const XPUTypeT *src_mask_ptr =
(src_mask_p == nullptr)
? (nullptr)
: (reinterpret_cast<const XPUTypeT *>(src_mask_p->data<T>()));
const XPUTypeT *out_linear_weight_ptr =
reinterpret_cast<const XPUTypeT *>(out_linear_weight.data<T>());
const DenseTensor *out_linear_bias_p = out_linear_bias.get_ptr();
const XPUTypeT *out_linear_bias_ptr =
reinterpret_cast<const XPUTypeT *>(out_linear_bias_p->data<T>());
const float *ln_scale_ptr =
(ln_scale_p == nullptr) ? (nullptr) : ln_scale_p->data<float>();
const float *ln_bias_ptr =
(ln_bias_p == nullptr) ? (nullptr) : ln_bias_p->data<float>();
// Output pointers.
XPUTypeT *qkv_transpose_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(transpose_out_2));
XPUTypeT *softmax_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(softmax_out));
XPUTypeT *attn_dropout_mask_out_ptr = reinterpret_cast<XPUTypeT *>(
dev_ctx.template Alloc<T>(attn_dropout_mask_out));
XPUTypeT *attn_dropout_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(attn_dropout_out));
XPUTypeT *fmha_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(fmha_out));
XPUTypeT *dropout_mask_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(dropout_mask_out));
XPUTypeT *out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(out));
XPUTypeT *bias_dropout_residual_out_ptr =
(bias_dropout_residual_out == nullptr)
? (nullptr)
: (reinterpret_cast<XPUTypeT *>(
dev_ctx.template Alloc<T>(bias_dropout_residual_out)));
float *ln_mean_ptr =
(ln_mean == nullptr)
? (nullptr)
: reinterpret_cast<float *>(dev_ctx.template Alloc<T>(ln_mean));
float *ln_var_ptr =
(ln_var == nullptr)
? (nullptr)
: reinterpret_cast<float *>(dev_ctx.template Alloc<T>(ln_var));
XPUTypeT *ln_out_ptr =
(ln_out == nullptr)
? (nullptr)
: (reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(ln_out)));
xpu::Context *xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
int l3_total_size = xpu_ctx->_l3_mgr.get_size();
XPUTypeT *qkv_before_transpos_ptr =
NULL; // x2[batch_size, seq_len, 3, num_heads,head_dims]
XPUTypeT *qk_ptr = NULL; // qk [batch_size, num_heads, seq_len, seq_len]
XPUTypeT *qkv_ptr = NULL; // qkv[batch_size, num_heads, seq_len, head_dims]
XPUTypeT *linear_out_ptr = NULL; // x4, x5 [batch_size, seq_len, embed_dims]
int temp_size_1 = batch_size * seq_len * 3 * num_heads * head_dims;
int temp_size_2 = batch_size * num_heads * seq_len * seq_len;
int temp_size_3 = batch_size * num_heads * seq_len * head_dims;
int temp_size_4 = batch_size * seq_len * embed_dims;
std::vector<int> temp_vec = {
temp_size_1, temp_size_2, temp_size_3, temp_size_4};
std::sort(temp_vec.begin(), temp_vec.end(), std::greater<int>());
XPUTypeT *max_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(temp_vec[0]);
PADDLE_ENFORCE_XDNN_NOT_NULL(max_gm_ptr);
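// All four scratch buffers alias one GM allocation sized for the largest of
// them; the loop below allocates a single L3 buffer of the largest size that
// still fits in L3 and redirects every scratch buffer no larger than that to it.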
qkv_before_transpos_ptr = max_gm_ptr;
qk_ptr = max_gm_ptr;
qkv_ptr = max_gm_ptr;
linear_out_ptr = max_gm_ptr;
int sizeof_t = sizeof(XPUTypeT);
for (size_t i = 0; i < temp_vec.size(); ++i) {
if (l3_total_size >= temp_vec[i] * sizeof_t) {
XPUTypeT *l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(temp_vec[i]);
qkv_before_transpos_ptr =
(temp_size_1 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
qk_ptr = (temp_size_2 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
qkv_ptr = (temp_size_3 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
linear_out_ptr = (temp_size_4 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
break;
}
}
int r = 0;
const XPUTypeT *x_cacl_ptr = input_x_ptr;
if (pre_layer_norm) {
r = xpu::layer_norm(xpu_ctx,
input_x_ptr,
ln_out_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_bias_ptr,
ln_mean_ptr,
ln_var_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
x_cacl_ptr = ln_out_ptr;
}
// fc
phi::XpuFcInfo qkv_fc_info;
qkv_fc_info.InitFcInfo(0,
batch_size * seq_len,
3 * num_heads * head_dims,
embed_dims,
false,
true,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
x_cacl_ptr,
qkv_weight_ptr,
qkv_before_transpos_ptr,
qkv_fc_info,
1.0f);
// bias
r = xpu::broadcast_add(xpu_ctx,
qkv_before_transpos_ptr,
qkv_bias_ptr,
qkv_before_transpos_ptr,
{batch_size * seq_len, 3 * num_heads * head_dims},
{3 * num_heads * head_dims});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
// transpose
r = xpu::transpose(xpu_ctx,
qkv_before_transpos_ptr,
qkv_transpose_out_ptr,
{batch_size, seq_len, 3, num_heads, head_dims},
{2, 0, 3, 1, 4});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
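// After this transpose the buffer layout is [3, batch_size, num_heads,
// seq_len, head_dims], so q, k and v occupy three contiguous chunks of
// batch_size * seq_len * num_heads * head_dims elements each.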
int qkv_every_size = batch_size * seq_len * num_heads * head_dims;
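// qkv_every_size covers only the q chunk of the transposed buffer, so the
// scale below applies the 1/sqrt(head_dims) attention scaling to q alone.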
{
float alpha = 1.0 / sqrt(head_dims);
r = scale(xpu_ctx,
qkv_transpose_out_ptr,
qkv_transpose_out_ptr,
qkv_every_size,
false,
alpha,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
// begin fmha
// 1. qk 2. qk + mask 3. softmax 4. dropout 5. qkv 6. transpose
{
const XPUTypeT *q_ptr = qkv_transpose_out_ptr;
const XPUTypeT *k_ptr = q_ptr + qkv_every_size;
const XPUTypeT *v_ptr = k_ptr + qkv_every_size;
phi::XpuFcInfo qk_fc_info;
qk_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
seq_len,
head_dims,
false,
true,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, q_ptr, k_ptr, qk_ptr, qk_fc_info, 1.0f);
if (src_mask_ptr) {
r = xpu::broadcast_add(xpu_ctx,
qk_ptr,
src_mask_ptr,
qk_ptr,
{batch_size, num_heads, seq_len, seq_len},
{batch_size, 1, 1, seq_len});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
}
// do softmax
r = xpu::softmax(xpu_ctx,
qk_ptr,
softmax_out_ptr,
{batch_size, num_heads, seq_len, seq_len},
3);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
// do dropout
phi::Dropout<XPUTypeT>(xpu_ctx,
softmax_out_ptr,
attn_dropout_mask_out_ptr,
attn_dropout_out_ptr,
attn_dropout_param,
batch_size * num_heads * seq_len * seq_len);
phi::XpuFcInfo qktv_fc_info;
qktv_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
head_dims,
seq_len,
false,
false,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, attn_dropout_out_ptr, v_ptr, qkv_ptr, qktv_fc_info, 1.0f);
r = xpu::transpose(xpu_ctx,
qkv_ptr,
fmha_out_ptr,
{batch_size, num_heads, seq_len, head_dims},
{0, 2, 1, 3});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
}
// linear_out
phi::XpuFcInfo linear_fc_info;
linear_fc_info.InitFcInfo(0,
batch_size * seq_len,
embed_dims,
embed_dims,
false,
false,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
fmha_out_ptr,
out_linear_weight_ptr,
linear_out_ptr,
linear_fc_info,
1.0f);
// out_linear_bias_ptr
r = xpu::broadcast_add(xpu_ctx,
linear_out_ptr,
out_linear_bias_ptr,
linear_out_ptr,
{batch_size * seq_len, embed_dims},
{embed_dims});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
Dropout(xpu_ctx,
linear_out_ptr,
dropout_mask_out_ptr,
linear_out_ptr,
dropout_param,
batch_size * seq_len * embed_dims);
XPUTypeT *real_out_ptr = out_ptr;
if (pre_layer_norm == false) {
real_out_ptr = bias_dropout_residual_out_ptr;
}
r = xpu::add(xpu_ctx,
linear_out_ptr,
input_x_ptr,
real_out_ptr,
batch_size * seq_len * embed_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
if (pre_layer_norm == false) {
r = xpu::layer_norm(xpu_ctx,
real_out_ptr,
out_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_bias_ptr,
ln_mean_ptr,
ln_var_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
}
}
} // namespace phi
PD_REGISTER_KERNEL(fused_attention,
XPU,
ALL_LAYOUT,
phi::FusedAttentionKernel,
float,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
struct XPUDropoutParam {
float dropout_prob;
bool is_upscale_in_train;
bool is_test;
bool fix_seed;
const phi::DenseTensor *tensor_seed;
int seed_val;
XPUDropoutParam() {
fix_seed = false;
is_test = false;
is_upscale_in_train = false;
dropout_prob = 0.5;
tensor_seed = nullptr;
seed_val = 0;
}
void initXPUDropoutParam(float dropout_prob_,
bool is_upscale_in_train_,
bool is_test_,
bool fix_seed_,
const phi::DenseTensor *tensor_seed,
int seed_val_) {
dropout_prob = dropout_prob_;
is_upscale_in_train = is_upscale_in_train_;
is_test = is_test_;
fix_seed = fix_seed_;
if (tensor_seed) {
seed_val = *(tensor_seed->data<int>());
} else {
seed_val = fix_seed ? seed_val_ : 0;
}
}
};
/******************
 * check whether an address resides in L3
 ******************/
static bool is_in_l3(const void *addr) {
int64_t addr_int = (int64_t)addr;
int addr_int_high = addr_int >> 32;
return (addr_int_high == 0);
}
/*************************
* dropout
*************************/
template <typename T>
void Dropout(xpu::Context *xpu_ctx,
const T *x,
T *mask,
T *y,
const XPUDropoutParam &param,
int len) {
using XPUType = typename XPUTypeTrait<T>::Type;
int r = XPU_SUCCESS;
if (param.dropout_prob == 0.0f) {
r = xpu::copy(xpu_ctx,
reinterpret_cast<const XPUType *>(x),
reinterpret_cast<XPUType *>(y),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
if (!param.is_test) {
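// Training path: a dropout probability of 1.0 zeroes both output and mask;
// otherwise xpu::dropout samples the mask and applies it in the requested mode.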
if (param.dropout_prob == 1.0f) {
r = xpu::constant(
xpu_ctx, reinterpret_cast<XPUType *>(y), len, XPUType(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
r = xpu::constant(
xpu_ctx, reinterpret_cast<XPUType *>(mask), len, XPUType(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
} else {
r = xpu::dropout(xpu_ctx,
reinterpret_cast<const XPUType *>(x),
reinterpret_cast<XPUType *>(y),
reinterpret_cast<XPUType *>(mask),
param.seed_val,
len,
param.is_upscale_in_train,
param.dropout_prob);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
}
} else {
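// Inference path: no mask is produced; downgrade_in_infer scales activations
// by (1 - dropout_prob), while upscale_in_train leaves them unchanged.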
float scale = (param.is_upscale_in_train)
? (1.0)
: (static_cast<float>(1.0f - param.dropout_prob));
r = xpu::scale(xpu_ctx,
reinterpret_cast<const XPUType *>(x),
reinterpret_cast<XPUType *>(y),
len,
false,
scale,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
}
template <typename T>
void DropoutGrad(xpu::Context *xpu_ctx,
const T *dy,
const T *mask,
T *dx,
const XPUDropoutParam &param,
int len) {
using XPUType = typename XPUTypeTrait<T>::Type;
if (param.dropout_prob == 0.0f) {
int r = xpu::copy(xpu_ctx,
reinterpret_cast<const XPUType *>(dy),
reinterpret_cast<XPUType *>(dx),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
if (!param.is_upscale_in_train) {
int r = xpu::mul(xpu_ctx,
reinterpret_cast<const XPUType *>(dy),
reinterpret_cast<const XPUType *>(mask),
reinterpret_cast<XPUType *>(dx),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
} else {
int r = xpu::dropout_grad(xpu_ctx,
reinterpret_cast<const XPUType *>(mask),
reinterpret_cast<const XPUType *>(dy),
reinterpret_cast<XPUType *>(dx),
param.dropout_prob,
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad");
}
}
} // namespace phi
@@ -1160,6 +1160,8 @@ set(STATIC_BUILD_TESTS
     test_eigh_op
     test_fake_quantize_op
     test_fetch_lod_tensor_array
+    test_fused_attention_op
+    test_fused_attention_op_api
     test_imperative_optimizer
     test_lamb_op
     test_layer_norm_op
@@ -1186,6 +1188,11 @@ set(STATIC_BUILD_TESTS
     test_while_op
     test_one_hot_v2_op)
+if(NOT WITH_GPU)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op_api)
+endif()
 foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
   py_test_modules(
     ${STATIC_BUILD_TEST}_static_build MODULES ${STATIC_BUILD_TEST} ENVS
...