Unverified commit ec857b85, authored by Yuang Liu, committed by GitHub

Add transpose_qkv_wb flags to the fused_attention_op. (#49494)

Parent 11f5848b
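This change lets the fused attention op accept the QKV projection weight either in the existing 4-D layout [3, num_heads, head_dim, embed_dim] (transpose_qkv_wb=False) or as a plain 2-D GEMM weight [embed_dim, 3 * embed_dim] with a [3 * embed_dim] bias (transpose_qkv_wb=True), in which case num_heads must be passed explicitly. A minimal numpy sketch of how the two weight layouts relate, consistent with how the updated tests construct them; the sizes and array names below are illustrative assumptions, not part of the patch:

    import numpy as np

    # Illustrative sizes (assumptions for this sketch only).
    embed_dim, num_heads = 8, 2
    head_dim = embed_dim // num_heads

    # transpose_qkv_wb=True layout: one fused GEMM weight, q/k/v stacked column-wise.
    w_2d = np.random.rand(embed_dim, 3 * embed_dim).astype("float32")

    # transpose_qkv_wb=False layout: [3, num_heads, head_dim, embed_dim].
    w_4d = w_2d.T.reshape(3, num_heads, head_dim, embed_dim)

    # Both layouts yield the same fused q/k/v projection for an input row x.
    x = np.random.rand(embed_dim).astype("float32")
    np.testing.assert_allclose(
        x @ w_2d,                                    # 2-D path: x @ W
        w_4d.reshape(3 * embed_dim, embed_dim) @ x,  # 4-D path: flatten, then W_flat @ x
        rtol=1e-5,
    )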
@@ -108,36 +108,64 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                      "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp");
+   int num_heads = ctx->Attrs().Get<int>("num_heads");
+   bool transpose_qkv_wb = ctx->Attrs().Get<bool>("transpose_qkv_wb");
    // x: qkv's input [batch_size, seq_len, dim_embed]
+   // if transpose_qkv_wb is False
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
+   // if transpose_qkv_wb is True
+   // y: qkv's weight: [dim_embed, 3 * dim_embed]
    auto x_dim = ctx->GetInputDim("X");
    auto y_dim = ctx->GetInputDim("QKVW");
-   PADDLE_ENFORCE_EQ(
-       x_dim.size(),
-       3,
-       platform::errors::InvalidArgument("The dimensions of x must be 3"
-                                         "(batch_size, seq_len, dim_embed),"
-                                         "but received dimensions of"
-                                         "Input is [%d]",
-                                         x_dim.size()));
+   int dim_head;
+   int hidden_size;
+   if (transpose_qkv_wb) {
+     PADDLE_ENFORCE_EQ(y_dim.size(),
+                       2,
+                       platform::errors::InvalidArgument(
+                           "The dimensions of qkv_weight must be 2 if enable"
+                           "transpose_qkv_wb: (dim_embed, 3 * dim_embed),"
+                           "but received dimensions of"
+                           "Input is [%d]",
+                           y_dim.size()));
+     PADDLE_ENFORCE_GT(num_heads,
+                       0,
+                       platform::errors::InvalidArgument(
+                           "The num_heads must be provided and greater than 0 "
+                           "if enable transpose_qkv_wb, but we got %d.",
+                           num_heads));
+     PADDLE_ENFORCE_EQ(y_dim[0] % num_heads,
+                       0,
+                       platform::errors::InvalidArgument(
+                           "First dim of qkv_w must be divisible by num heads "
+                           "if enable transpose_qkv_wb, but receive first "
+                           "dim of qkv_w is %d and num_heads is %d.",
+                           y_dim[0],
+                           num_heads));
+     if (ctx->Attrs().Get<int>("ring_id") == -1) {
+       PADDLE_ENFORCE_EQ(y_dim[0] * 3,
+                         y_dim[1],
+                         platform::errors::InvalidArgument(
+                             "The dimensions of qkv_weight must be 2"
+                             "(dim_embed, 3 * dim_embed)."));
+     }
+     dim_head = y_dim[0] / num_heads;
+     hidden_size = y_dim[0];
+   } else {
      PADDLE_ENFORCE_EQ(y_dim.size(),
                        4,
                        platform::errors::InvalidArgument(
-                          "The dimensions of qkv_weight must be 4"
-                          "(3, num_head, dim_head, dim_embed),"
-                          "but received dimensions of"
-                          "Input is [%d]",
+                          "The dimensions of qkv_weight must be 4 if not"
+                          "enable transpose_qkv_wb: (3, num_head, dim_head, "
+                          "dim_embed), but received [%d]",
                           y_dim.size()));
-   PADDLE_ENFORCE_EQ(x_dim[2],
-                     y_dim[3],
-                     platform::errors::InvalidArgument(
-                         "ShapeError: the dimension of x_dim[2] and y_dim[3]"
-                         "must be equal. But received: the shape "
-                         "of input x = [%s], and the shape of "
-                         "input qkv_weight = [%s]",
-                         x_dim,
-                         y_dim));
+     PADDLE_ENFORCE_EQ(y_dim[0],
+                       3,
+                       platform::errors::InvalidArgument(
+                           "First dim of qkv_w must be 3 if disable "
+                           "transpose_qkv_wb, but we got %d.",
+                           y_dim[0]));
      if (ctx->Attrs().Get<int>("ring_id") == -1) {
        PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
                          y_dim[3],
@@ -147,6 +175,29 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                              "and must satisfy the limitations: "
                              "(num_head * dim_head == dim_embed)"));
      }
+     num_heads = y_dim[1];
+     dim_head = y_dim[2];
+     hidden_size = y_dim[3];
+   }
+   PADDLE_ENFORCE_EQ(
+       x_dim.size(),
+       3,
+       platform::errors::InvalidArgument("The dimensions of x must be 3"
+                                         "(batch_size, seq_len, dim_embed),"
+                                         "but received dimensions of"
+                                         "Input is [%d]",
+                                         x_dim.size()));
+   PADDLE_ENFORCE_EQ(x_dim[2],
+                     hidden_size,
+                     platform::errors::InvalidArgument(
+                         "ShapeError: the dimension of x_dim[2] and y_dim[3] "
+                         "(y_dim[1] if enable transpose_qkv_w) "
+                         "must be equal. But received: the shape "
+                         "of input x = [%s], and the shape of "
+                         "input qkv_weight = [%s]",
+                         x_dim,
+                         y_dim));
    if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
      ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
@@ -157,17 +208,27 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
      ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
      ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
    }
+   if (transpose_qkv_wb) {
+     // [batch_size, seq_len, 3 * hidden_size]
+     ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], 3 * hidden_size});
+     if (ctx->HasInput("QKVBias")) {
+       ctx->SetOutputDim("QKVBiasOut", {x_dim[0], x_dim[1], 3 * hidden_size});
+     }
+   } else {
      // [batch_size, seq_len, 3, num_head, head_size]
-    ctx->SetOutputDim("QKVOut",
-                      {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
+    ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], 3, num_heads, dim_head});
     if (ctx->HasInput("QKVBias")) {
       ctx->SetOutputDim("QKVBiasOut",
-                        {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
+                        {x_dim[0], x_dim[1], 3, num_heads, dim_head});
     }
+   }
    // [3, batch_size, num_head, seq_len, head_size]
    ctx->SetOutputDim("TransposeOut2",
-                     {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
+                     {3, x_dim[0], num_heads, x_dim[1], dim_head});
    // cache_seq_len + seq_len if cache else seq_len
    auto out_seq_len = x_dim[1];
@@ -193,11 +254,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                          x_dim[0],
                          c_dim[1]));  // batch_size
      PADDLE_ENFORCE_EQ(c_dim[2],
-                       y_dim[1],
+                       num_heads,
                        paddle::platform::errors::InvalidArgument(
                            "The third dim of CacheKV must be equal with num "
                            "head %d, but got %d",
-                           y_dim[1],
+                           num_heads,
                            c_dim[2]));  // num_head
      // In compile stage, input seq_len can be -1, in that case
      // c_dim[3] may < 0 in while
@@ -209,12 +270,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
              "The forth dim of CacheKV must be greater than 0, but got %d",
              c_dim[3]));  // cache_seq_len
      }
      PADDLE_ENFORCE_EQ(c_dim[4],
-                       y_dim[2],
+                       dim_head,
                        paddle::platform::errors::InvalidArgument(
                            "The fifth dim of CacheKV must be equal with head "
                            "size %d, but got %d",
-                           y_dim[2],
+                           dim_head,
                            c_dim[4]));  // head_size
      out_seq_len += c_dim[3];
@@ -224,25 +286,26 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
    }
    // [batch, num_head, seq_len, out_seq_len]
-   ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
+   ctx->SetOutputDim("QKOut", {x_dim[0], num_heads, x_dim[1], out_seq_len});
    if (ctx->HasInput("SrcMask")) {
      ctx->SetOutputDim("SrcMaskOut",
-                       {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
+                       {x_dim[0], num_heads, x_dim[1], out_seq_len});
    }
    // the same as QKOut's shape.
    ctx->SetOutputDim("AttnDropoutOut",
-                     {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
+                     {x_dim[0], num_heads, x_dim[1], out_seq_len});
    if (ctx->Attrs().Get<bool>("is_test") == false) {
      ctx->SetOutputDim("AttnDropoutMaskOut",
-                       {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
+                       {x_dim[0], num_heads, x_dim[1], out_seq_len});
    }
    ctx->SetOutputDim("SoftmaxOut",
-                     {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
+                     {x_dim[0], num_heads, x_dim[1], out_seq_len});
    // [batch_size, num_heads, seq_len, head_dim]
-   ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
+   ctx->SetOutputDim("QKTVOut", {x_dim[0], num_heads, x_dim[1], dim_head});
    // [batch_size, seq_len, number of heads*head size]
-   ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
+   ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], num_heads, dim_head});
    ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));
    if (ctx->Attrs().Get<bool>("is_test") == false) {
@@ -315,6 +378,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
    AddOutput("CacheKVOut", "The udpated cache KV.");
    AddOutput("Y", "Result after attention.");
+   AddAttr<int>("num_heads", "The number head for multi_head_attention.")
+       .SetDefault(-1);
+   AddAttr<bool>("transpose_qkv_wb",
+                 "The qkv_w shape is (h, 3h), do transpose to it.")
+       .SetDefault(false);
    AddAttr<bool>("pre_layer_norm",
                  "if true, the attention op uses pre_layer_norm architecure, "
                  "else, uses post_layer_norm architecuture. "
...
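The InferShape changes above derive the head geometry from different inputs depending on the flag; a small Python sketch of that derivation and of the resulting QKVOut shape, mirroring the two branches shown in the hunk (sizes below are illustrative):

    def infer_qkv_out_shape(x_dim, y_dim, num_heads, transpose_qkv_wb):
        """Sketch of the two InferShape branches above (not the real op code)."""
        # x_dim: [batch_size, seq_len, dim_embed]; y_dim: dims of QKVW.
        if transpose_qkv_wb:
            hidden_size = y_dim[0]  # y_dim: [dim_embed, 3 * dim_embed]
            return [x_dim[0], x_dim[1], 3 * hidden_size]
        num_heads, dim_head = y_dim[1], y_dim[2]  # y_dim: [3, num_head, dim_head, dim_embed]
        return [x_dim[0], x_dim[1], 3, num_heads, dim_head]

    print(infer_qkv_out_shape([2, 16, 8], [8, 24], num_heads=2, transpose_qkv_wb=True))
    # [2, 16, 24]
    print(infer_qkv_out_shape([2, 16, 8], [3, 2, 4, 8], num_heads=-1, transpose_qkv_wb=False))
    # [2, 16, 3, 2, 4]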
@@ -25,9 +25,12 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
+#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
+#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
@@ -87,8 +90,14 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
    auto *ln_var = ctx.Output<phi::DenseTensor>("LnVariance");
    auto *ln_out = ctx.Output<phi::DenseTensor>("LnOut");
+   const auto num_heads = ctx.Attr<int>("num_heads");
+   const auto transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");
    // x: qkv's input [batch_size, seq_len, dim_embed]
+   // if transpose_qkv_wb is False
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
+   // if transpose_qkv_wb is True
+   // y: qkv's weight: [dim_embed, 3 * dim_embed]
    auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
    auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
    auto *qkv_out = ctx.Output<phi::DenseTensor>("QKVOut");
@@ -206,8 +215,16 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
-   int num_head = qkv_w_dims[1];
-   int dim_head = qkv_w_dims[2];
+   int num_head;
+   int dim_head;
+   // get num_head and dim_head in two different ways
+   if (!transpose_qkv_wb) {
+     num_head = qkv_w_dims[1];
+     dim_head = qkv_w_dims[2];
+   } else {
+     num_head = num_heads;
+     dim_head = dim_embed / num_head;
+   }
    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
@@ -222,9 +239,10 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
      compute_bias = false;
    }
    // (transA, transB, compute_bias) = (false, true, true)
+   bool transB = transpose_qkv_wb ? false : true;
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     false,
-                                    true,
+                                    transB,
                                     bsz_seq,
                                     output_size,
                                     input_size,
@@ -288,6 +306,13 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
      qkv_compute.ComputeForward(
          qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out);
    }
+   if (transpose_qkv_wb) {
+     // resize the output for fmha compute
+     qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
+     qkv_bias_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
+   }
    if (qkv_bias == nullptr) {
      fmha_ref_compute.ComputeForward(*qkv_out,
                                      cache_kv,
@@ -316,6 +341,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
                                      fmha_out);
    }
+   if (transpose_qkv_wb) {
+     // resize the output back to make the shape compatible with infer shape
+     qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
+     qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
+   }
    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight: [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
@@ -374,6 +405,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
+   const int num_heads = ctx.Attr<int>("num_heads");
+   const bool transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    const float ln2epsilon = ctx.Attr<float>("ln_epsilon");
@@ -544,8 +577,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
-   int num_head = qkv_w_dims[1];
-   int dim_head = qkv_w_dims[2];
+   int num_head;
+   int dim_head;
+   if (!transpose_qkv_wb) {
+     num_head = qkv_w_dims[1];
+     dim_head = qkv_w_dims[2];
+   } else {
+     num_head = num_heads;
+     dim_head = dim_embed / num_head;
+   }
    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
@@ -562,7 +602,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
    }
    bool transA = false;
-   bool transB = true;
+   bool transB = transpose_qkv_wb ? false : true;
    bool compute_qkv_bias = qkv_bias ? true : false;
    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
@@ -655,6 +695,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
                                        d_out_linear_weight,
                                        nullptr);
+   if (transpose_qkv_wb) {
+     if (compute_qkv_bias) {
+       d_qkv_bias_out->Resize(
+           {batch_size, max_seq_len, 3, num_head, dim_head});
+     } else {
+       d_qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
+     }
+   }
    if (qkv_bias != nullptr) {
      fmha_ref_compute.ComputeBackward(*transpose_out_2,
                                       has_attn_dropout ? src_mask : nullptr,
@@ -691,6 +740,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
                                       d_qkv_out);
    }
+   if (transpose_qkv_wb) {
+     if (compute_qkv_bias) {
+       d_qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
+     } else {
+       d_qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
+     }
+   }
    if (pre_layer_norm) {
      auto *ln_mean = ctx.Input<phi::DenseTensor>("LnMean");
      auto *ln_var = ctx.Input<phi::DenseTensor>("LnVariance");
...
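At the kernel level the only GEMM difference is the transB flag set above: the 4-D weight is flattened to [3 * embed_dim, embed_dim] and used with transB = true (X times W transposed), while the 2-D [embed_dim, 3 * embed_dim] weight is used directly with transB = false (X times W), after which QKVOut is Resized to the [batch, seq, 3, num_head, dim_head] layout the FMHA stage expects. A small numpy sketch of why the two paths agree; shapes and names here are illustrative assumptions:

    import numpy as np

    bsz_seq, embed_dim, num_heads = 4, 8, 2
    head_dim = embed_dim // num_heads

    x = np.random.rand(bsz_seq, embed_dim).astype("float32")
    w_2d = np.random.rand(embed_dim, 3 * embed_dim).astype("float32")  # transpose_qkv_wb=True weight
    w_flat = w_2d.T  # the equivalent 4-D weight flattened to [3 * embed_dim, embed_dim]

    y_transb_true = x @ w_flat.T   # transpose_qkv_wb=False path (transB = true)
    y_transb_false = x @ w_2d      # transpose_qkv_wb=True path (transB = false)
    np.testing.assert_allclose(y_transb_true, y_transb_false, rtol=1e-5)

    # The kernel's Resize step: view the flat result as [bsz_seq, 3, num_heads, head_dim].
    qkv = y_transb_false.reshape(bsz_seq, 3, num_heads, head_dim)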
@@ -103,6 +103,7 @@ class TestFusedAttentionOp(OpTest):
            self.query_length,
            self.query_length,
        )
+       self.transpose_qkv_wb = False

    def generate_input_data(self):
        self.query = np.random.rand(
@@ -265,6 +266,7 @@ class TestFusedAttentionOp(OpTest):
        qkv_bias = np.concatenate(
            (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())
        )
+       if not self.transpose_qkv_wb:
            qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
        qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
        out_linear_bias = paddle.to_tensor(
@@ -276,12 +278,20 @@ class TestFusedAttentionOp(OpTest):
        ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
        ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)
+       if not self.transpose_qkv_wb:
            q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
            k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
            v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
+       else:
+           q_proj_weight = q_proj_weight.numpy()
+           k_proj_weight = k_proj_weight.numpy()
+           v_proj_weight = v_proj_weight.numpy()
+
+       concatenate_axis = 1 if self.transpose_qkv_wb else 0
        qkv_weight = np.concatenate(
-           (q_proj_weight, k_proj_weight, v_proj_weight)
+           (q_proj_weight, k_proj_weight, v_proj_weight), axis=concatenate_axis
        )
+       if not self.transpose_qkv_wb:
            qkv_weight = qkv_weight.reshape(
                (3, self.num_heads, self.head_dim, self.embed_dim)
            )
@@ -317,6 +327,8 @@ class TestFusedAttentionOp(OpTest):
            self.dropout_prob,
            self.attn_dropout_prob,
            ln2_epsilon,
+           num_heads=self.num_heads,
+           transpose_qkv_wb=self.transpose_qkv_wb,
        )

        if self.has_cache_kv:
@@ -344,6 +356,19 @@ class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):
        self.bias_attr = False


+class TestFusedAttentionAPITransposeWAndB(TestFusedAttentionOp):
+    def config(self):
+        super().config()
+        self.transpose_qkv_wb = True
+
+
+class TestFusedAttentionAPITransposeWAndBWithoutBias(TestFusedAttentionOp):
+    def config(self):
+        super().config()
+        self.transpose_qkv_wb = True
+        self.bias_attr = False
+
+
class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
    def config(self):
        super().config()
...
@@ -81,6 +81,8 @@ def compute_reference(
    qkv_bias,
    out_linear_weight,
    out_linear_bias,
+   num_head,
+   transpose_qkv_wb,
):
    batch_size = query.shape[0]
    seq_len = query.shape[1]
@@ -93,8 +95,8 @@ def compute_reference(
    if pre_layer_norm:
        ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias)

-   num_head = qkv_weight.shape[1]
-   head_dim = qkv_weight.shape[2]
+   head_dim = embed_dim // num_head
+   if not transpose_qkv_wb:
        # embed_dim, 3, num_heads, self.head_dim
        qkv_weight = qkv_weight.transpose((3, 0, 1, 2))
        qkv_weight = qkv_weight.reshape(
@@ -106,6 +108,13 @@ def compute_reference(
        qkv_bias = qkv_bias.reshape(
            qkv_bias.shape[0] * qkv_bias.shape[1] * qkv_bias.shape[2]
        )
+   else:
+       assert len(qkv_weight.shape) == 2
+       assert qkv_weight.shape[0] * 3 == qkv_weight.shape[1]
+       if qkv_bias is not None:
+           assert len(qkv_bias.shape) == 1
+           assert qkv_bias.shape[0] == qkv_weight.shape[1]
+
    if pre_layer_norm:
        ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
        qkv = fc(ln_out, qkv_weight)
@@ -189,6 +198,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
        self.setPreLn()
        self.setAttnMask()
        self.setBiasAttr()
+       self.setTransposeWAndB()
        self.config()
        self.generate_input_data()
@@ -209,6 +219,9 @@ class TestFusedAttentionAPI(unittest.TestCase):
    def setBiasAttr(self):
        self.bias_attr = None

+   def setTransposeWAndB(self):
+       self.transpose_qkv_wb = False
+
    def setPreLn(self):
        self.pre_layer_norm = False
@@ -284,6 +297,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
            self.bias_attr,
            self.weight_attr,
            self.bias_attr,
+           transpose_qkv_wb=self.transpose_qkv_wb,
        )
        if self.bias_attr is not False:
            qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype(
@@ -323,6 +337,8 @@ class TestFusedAttentionAPI(unittest.TestCase):
            fused_attn_qkv_bias,
            fused_attn.linear_weight.numpy(),
            fused_attn_linear_bias,
+           num_head=self.num_heads,
+           transpose_qkv_wb=self.transpose_qkv_wb,
        )
        np.testing.assert_allclose(
            ref_out, out.numpy(), rtol=self.rtol, atol=self.atol
@@ -346,6 +362,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
            self.bias_attr,
            self.weight_attr,
            self.bias_attr,
+           transpose_qkv_wb=self.transpose_qkv_wb,
        )
        x = paddle.static.data(
@@ -562,6 +579,8 @@ class TestFusedAttentionAPI(unittest.TestCase):
            qkv_bias,
            linear_weight,
            linear_bias,
+           num_head=self.num_heads,
+           transpose_qkv_wb=self.transpose_qkv_wb,
        )
        np.testing.assert_allclose(ref_out, out, rtol=self.rtol, atol=self.atol)
@@ -583,5 +602,18 @@ class TestFusedAttentionAPIBiasIsNone(TestFusedAttentionAPI):
        self.bias_attr = False


+class TestFusedAttentionAPITransposeWAndB(TestFusedAttentionAPI):
+    def setTransposeWAndB(self):
+        self.transpose_qkv_wb = True
+
+
+class TestFusedAttentionAPITransposeWAndBWithoutBias(TestFusedAttentionAPI):
+    def setTransposeWAndB(self):
+        self.transpose_qkv_wb = True
+
+    def setBiasAttr(self):
+        self.bias_attr = False
+
+
if __name__ == "__main__":
    unittest.main()
@@ -482,6 +482,8 @@ def fused_multi_head_attention(
    mode='upscale_in_train',
    ring_id=-1,
    add_residual=True,
+   num_heads=-1,
+   transpose_qkv_wb=False,
    name=None,
):
    r"""
@@ -567,6 +569,8 @@ def fused_multi_head_attention(
            - inference: out = input * (1.0 - p)
        ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp
        add_residual (bool, optional): Whether add residual at the end. Default is True.
+       num_heads (int, optional): If enable transpose_qkv_wb, should provide the num_heads. Default is -1, means not transpose qkv wb.
+       transpose_qkv_wb (bool, optional): Whether transpose the qkv_weight and qkv_bias in the op. Only support GPU for now. Default is false, means not transpose qkv wb.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
@@ -617,6 +621,7 @@ def fused_multi_head_attention(
        # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
        # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
        # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
+       if not transpose_qkv_wb:
            assert (
                len(qkv_weight.shape) == 4
            ), "The dims of the shape of qkv_weight should be 4."
@@ -629,9 +634,35 @@ def fused_multi_head_attention(
            if ring_id == -1:
                # under mp, the num head will be split, this equation will not hold
                assert (
-                   qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[3]
+                   qkv_weight.shape[1] * qkv_weight.shape[2]
+                   == qkv_weight.shape[3]
                ), "embed_dim must be divisible by num_heads."
+       else:
+           assert (
+               num_heads > 0
+           ), "When enable transpose_qkv_wb, the num_heads should be provided and greater than 0."
+           assert len(qkv_weight.shape) == 2, (
+               "When enable transpose_qkv_wb, the dims of the shape of qkv_weight "
+               "should be 2 when enable transpose_qkv_wb."
+           )
+           if ring_id == -1:
+               # under mp, the num head will be split, this equation will not hold
+               assert qkv_weight.shape[1] == 3 * qkv_weight.shape[0], (
+                   "When enable transpose_qkv_wb, the shape of qkv_weight should be "
+                   "[embed_dim, 3 * embed_dim] when enable transpose_qkv_wb."
+               )
+           assert qkv_weight.shape[0] == x.shape[2], (
+               "When enable transpose_qkv_wb, the 1st dim of qkv_weight and 2nd dim of x "
+               "should be the same, i.e., embed_dim."
+           )
+           if qkv_bias is not None:
+               assert (
+                   len(qkv_bias.shape) == 1
+               ), "When enable transpose_qkv_wb, the dims of the shape of qkv_bias should be 1."
+               assert qkv_bias.shape[0] == qkv_weight.shape[1], (
+                   "When enable transpose_qkv_wb, the 1st dim of qkv_bias and 2nd dim of "
+                   "qkv_weight should be the same, i.e., embed_dim."
+               )
        (
            _,
            _,
@@ -665,6 +696,10 @@ def fused_multi_head_attention(
            linear_bias,
            ln_scale,
            ln_bias,
+           'num_heads',
+           num_heads,
+           'transpose_qkv_wb',
+           transpose_qkv_wb,
            'pre_layer_norm',
            pre_layer_norm,
            'epsilon',
@@ -754,6 +789,8 @@ def fused_multi_head_attention(
            'dropout_implementation': mode,
            'add_residual': add_residual,
            'ring_id': ring_id,
+           'num_heads': num_heads,
+           'transpose_qkv_wb': transpose_qkv_wb,
        }

        # set outputs
...
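The new transpose_qkv_wb branch in fused_multi_head_attention only adds shape validation; a compact numpy restatement of those checks for the single-card case (ring_id == -1), with illustrative sizes:

    import numpy as np

    embed_dim, num_heads, batch_size, seq_len = 8, 2, 2, 4
    x = np.zeros((batch_size, seq_len, embed_dim))
    qkv_weight = np.zeros((embed_dim, 3 * embed_dim))
    qkv_bias = np.zeros((3 * embed_dim,))

    assert num_heads > 0                                    # num_heads must be provided
    assert qkv_weight.ndim == 2                             # [embed_dim, 3 * embed_dim]
    assert qkv_weight.shape[1] == 3 * qkv_weight.shape[0]
    assert qkv_weight.shape[0] == x.shape[2]                # matches dim_embed of x
    assert qkv_bias.ndim == 1 and qkv_bias.shape[0] == qkv_weight.shape[1]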
@@ -246,6 +246,11 @@ class FusedMultiHeadAttention(Layer):
        epsilon (float, optional): The small value added to the variance to prevent
            division by zero. Default: 1e-05.
        nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using tensor parallel.
+       transpose_qkv_wb (bool, optional): Support input qkv matmul weight shape as
+           [hidden_size, 3 * hidden_size] and qkv matmul bias shape as [3 * hidden_size].
+           Will transpose the weight to [3, num_head, head_dim, hidden_size] and transpose bias to
+           [3, num_head, hidden_size] in the fused_attention_op. Only support for GPU for now.
+           The default value is False, which is not do transpose to qkv_w and qkv_b.
        ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using tensor parallel.

    Examples:
@@ -283,6 +288,7 @@ class FusedMultiHeadAttention(Layer):
        epsilon=1e-5,
        nranks=1,
        ring_id=-1,
+       transpose_qkv_wb=False,
        name=None,
    ):
        super().__init__()
@@ -315,22 +321,31 @@ class FusedMultiHeadAttention(Layer):
        # tensor model parallel
        assert num_heads % nranks == 0
-       num_heads = num_heads // nranks
+       self.num_heads = num_heads // nranks
+       self.transpose_qkv_wb = transpose_qkv_wb
+       if self.transpose_qkv_wb:
+           # For tensor model parallel, use num_head * head_dim to compute the real shape.
+           qkv_wight_shape = [embed_dim, 3 * self.num_heads * self.head_dim]
+           qkv_bias_shape = [3 * self.num_heads * self.head_dim]
+       else:
+           qkv_wight_shape = [3, self.num_heads, self.head_dim, embed_dim]
+           qkv_bias_shape = [3, self.num_heads, self.head_dim]

        self.qkv_weight = self.create_parameter(
-           shape=[3, num_heads, self.head_dim, embed_dim],
+           shape=qkv_wight_shape,
            attr=qkv_weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )
        self.qkv_bias = self.create_parameter(
-           shape=[3, num_heads, self.head_dim],
+           shape=qkv_bias_shape,
            attr=qkv_bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )
        self.linear_weight = self.create_parameter(
-           shape=[num_heads * self.head_dim, embed_dim],
+           shape=[self.num_heads * self.head_dim, embed_dim],
            attr=linear_weight_attr,
            dtype=self._dtype,
            is_bias=False,
@@ -436,6 +451,8 @@ class FusedMultiHeadAttention(Layer):
            ln_epsilon=self._epsilon,
            training=self.training,
            ring_id=self._ring_id,
+           num_heads=self.num_heads,
+           transpose_qkv_wb=self.transpose_qkv_wb,
            name=self.name,
        )
        return out
...
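A minimal usage sketch of the new constructor flag, assuming the layer's existing API (paddle.incubate.nn.FusedMultiHeadAttention(embed_dim, num_heads, ...)), a GPU build of Paddle, and the shapes documented above; only transpose_qkv_wb and the resulting parameter shapes come from this change, the rest is an assumption about the surrounding API:

    import paddle
    from paddle.incubate.nn import FusedMultiHeadAttention

    embed_dim, num_heads = 128, 8
    attn = FusedMultiHeadAttention(embed_dim, num_heads, transpose_qkv_wb=True)

    # qkv_weight is created as [embed_dim, 3 * embed_dim] instead of
    # [3, num_heads, head_dim, embed_dim]; qkv_bias as [3 * embed_dim].
    print(attn.qkv_weight.shape)  # [128, 384]
    print(attn.qkv_bias.shape)    # [384]

    x = paddle.rand([2, 16, embed_dim])  # [batch_size, seq_len, embed_dim]
    out = attn(x)                        # [2, 16, embed_dim]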