From ec857b850dd2f019ab3e658a920a878b8ca53630 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 5 Jan 2023 10:58:30 +0800 Subject: [PATCH] Add transpose_qkv_wb flags to the fused_attention_op. (#49494) --- .../operators/fused/fused_attention_op.cc | 144 +++++++++++++----- .../operators/fused/fused_attention_op.cu | 69 ++++++++- .../unittests/test_fused_attention_op.py | 41 ++++- .../unittests/test_fused_attention_op_api.py | 56 +++++-- .../nn/functional/fused_transformer.py | 65 ++++++-- .../incubate/nn/layer/fused_transformer.py | 25 ++- 6 files changed, 318 insertions(+), 82 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 6b1f533b34..7d00dda194 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -108,10 +108,77 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp"); + int num_heads = ctx->Attrs().Get("num_heads"); + bool transpose_qkv_wb = ctx->Attrs().Get("transpose_qkv_wb"); + // x: qkv's input [batch_size, seq_len, dim_embed] + // if transpose_qkv_wb is False // y: qkv's weight: [3, num_head, dim_head, dim_embed] + // if transpose_qkv_wb is True + // y: qkv's weight: [dim_embed, 3 * dim_embed] auto x_dim = ctx->GetInputDim("X"); auto y_dim = ctx->GetInputDim("QKVW"); + int dim_head; + int hidden_size; + if (transpose_qkv_wb) { + PADDLE_ENFORCE_EQ(y_dim.size(), + 2, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 2 if enable" + "transpose_qkv_wb: (dim_embed, 3 * dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_GT(num_heads, + 0, + platform::errors::InvalidArgument( + "The num_heads must be provided and greater than 0 " + "if enable transpose_qkv_wb, but we got %d.", + num_heads)); + PADDLE_ENFORCE_EQ(y_dim[0] % num_heads, + 0, + platform::errors::InvalidArgument( + "First dim of qkv_w must be divisible by num heads " + "if enable transpose_qkv_wb, but receive first " + "dim of qkv_w is %d and num_heads is %d.", + y_dim[0], + num_heads)); + if (ctx->Attrs().Get("ring_id") == -1) { + PADDLE_ENFORCE_EQ(y_dim[0] * 3, + y_dim[1], + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 2" + "(dim_embed, 3 * dim_embed).")); + } + dim_head = y_dim[0] / num_heads; + hidden_size = y_dim[0]; + } else { + PADDLE_ENFORCE_EQ(y_dim.size(), + 4, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4 if not" + "enable transpose_qkv_wb: (3, num_head, dim_head, " + "dim_embed), but received [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ(y_dim[0], + 3, + platform::errors::InvalidArgument( + "First dim of qkv_w must be 3 if disable " + "transpose_qkv_wb, but we got %d.", + y_dim[0])); + if (ctx->Attrs().Get("ring_id") == -1) { + PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], + y_dim[3], + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "and must satisfy the limitations: " + "(num_head * dim_head == dim_embed)")); + } + num_heads = y_dim[1]; + dim_head = y_dim[2]; + hidden_size = y_dim[3]; + } PADDLE_ENFORCE_EQ( x_dim.size(), 3, @@ -120,34 +187,18 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "but received dimensions of" "Input is [%d]", x_dim.size())); - PADDLE_ENFORCE_EQ(y_dim.size(), - 4, - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "but received dimensions of" - "Input is [%d]", - y_dim.size())); + PADDLE_ENFORCE_EQ(x_dim[2], - y_dim[3], + hidden_size, platform::errors::InvalidArgument( - "ShapeError: the dimension of x_dim[2] and y_dim[3]" + "ShapeError: the dimension of x_dim[2] and y_dim[3] " + "(y_dim[1] if enable transpose_qkv_w) " "must be equal. But received: the shape " "of input x = [%s], and the shape of " "input qkv_weight = [%s]", x_dim, y_dim)); - if (ctx->Attrs().Get("ring_id") == -1) { - PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], - y_dim[3], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - } - if (ctx->Attrs().Get("pre_layer_norm") == true) { ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); @@ -157,17 +208,27 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); } - // [batch_size, seq_len, 3, num_head, head_size] - ctx->SetOutputDim("QKVOut", - {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); - if (ctx->HasInput("QKVBias")) { - ctx->SetOutputDim("QKVBiasOut", - {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + if (transpose_qkv_wb) { + // [batch_size, seq_len, 3 * hidden_size] + ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], 3 * hidden_size}); + + if (ctx->HasInput("QKVBias")) { + ctx->SetOutputDim("QKVBiasOut", {x_dim[0], x_dim[1], 3 * hidden_size}); + } + } else { + // [batch_size, seq_len, 3, num_head, head_size] + ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], 3, num_heads, dim_head}); + + if (ctx->HasInput("QKVBias")) { + ctx->SetOutputDim("QKVBiasOut", + {x_dim[0], x_dim[1], 3, num_heads, dim_head}); + } } + // [3, batch_size, num_head, seq_len, head_size] ctx->SetOutputDim("TransposeOut2", - {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); + {3, x_dim[0], num_heads, x_dim[1], dim_head}); // cache_seq_len + seq_len if cache else seq_len auto out_seq_len = x_dim[1]; @@ -193,11 +254,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { x_dim[0], c_dim[1])); // batch_size PADDLE_ENFORCE_EQ(c_dim[2], - y_dim[1], + num_heads, paddle::platform::errors::InvalidArgument( "The third dim of CacheKV must be equal with num " "head %d, but got %d", - y_dim[1], + num_heads, c_dim[2])); // num_head // In compile stage, input seq_len can be -1, in that case // c_dim[3] may < 0 in while @@ -209,12 +270,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "The forth dim of CacheKV must be greater than 0, but got %d", c_dim[3])); // cache_seq_len } + PADDLE_ENFORCE_EQ(c_dim[4], - y_dim[2], + dim_head, paddle::platform::errors::InvalidArgument( "The fifth dim of CacheKV must be equal with head " "size %d, but got %d", - y_dim[2], + dim_head, c_dim[4])); // head_size out_seq_len += c_dim[3]; @@ -224,25 +286,26 @@ class FusedAttentionOp : public framework::OperatorWithKernel { } // [batch, num_head, seq_len, out_seq_len] - ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); + ctx->SetOutputDim("QKOut", {x_dim[0], num_heads, x_dim[1], out_seq_len}); if (ctx->HasInput("SrcMask")) { ctx->SetOutputDim("SrcMaskOut", - {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); + {x_dim[0], num_heads, x_dim[1], out_seq_len}); } // the same as QKOut's shape. ctx->SetOutputDim("AttnDropoutOut", - {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); + {x_dim[0], num_heads, x_dim[1], out_seq_len}); if (ctx->Attrs().Get("is_test") == false) { ctx->SetOutputDim("AttnDropoutMaskOut", - {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); + {x_dim[0], num_heads, x_dim[1], out_seq_len}); } ctx->SetOutputDim("SoftmaxOut", - {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); + {x_dim[0], num_heads, x_dim[1], out_seq_len}); // [batch_size, num_heads, seq_len, head_dim] - ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); + ctx->SetOutputDim("QKTVOut", {x_dim[0], num_heads, x_dim[1], dim_head}); // [batch_size, seq_len, number of heads*head size] - ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}); + ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], num_heads, dim_head}); + ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X")); if (ctx->Attrs().Get("is_test") == false) { @@ -315,6 +378,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("CacheKVOut", "The udpated cache KV."); AddOutput("Y", "Result after attention."); + AddAttr("num_heads", "The number head for multi_head_attention.") + .SetDefault(-1); + AddAttr("transpose_qkv_wb", + "The qkv_w shape is (h, 3h), do transpose to it.") + .SetDefault(false); AddAttr("pre_layer_norm", "if true, the attention op uses pre_layer_norm architecure, " "else, uses post_layer_norm architecuture. " diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index d963e73965..91dbf71bbc 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -25,9 +25,12 @@ limitations under the License. */ #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/functors.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/transpose_function.cu.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/process_group_nccl.h" @@ -87,8 +90,14 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *ln_var = ctx.Output("LnVariance"); auto *ln_out = ctx.Output("LnOut"); + const auto num_heads = ctx.Attr("num_heads"); + const auto transpose_qkv_wb = ctx.Attr("transpose_qkv_wb"); + // x: qkv's input [batch_size, seq_len, dim_embed] + // if transpose_qkv_wb is False // y: qkv's weight: [3, num_head, dim_head, dim_embed] + // if transpose_qkv_wb is True + // y: qkv's weight: [dim_embed, 3 * dim_embed] auto *qkv_weight = ctx.Input("QKVW"); auto *qkv_bias = ctx.Input("QKVBias"); auto *qkv_out = ctx.Output("QKVOut"); @@ -206,8 +215,16 @@ class FusedAttentionOpKernel : public framework::OpKernel { int max_seq_len = input_x_dims[1]; int dim_embed = input_x_dims[2]; - int num_head = qkv_w_dims[1]; - int dim_head = qkv_w_dims[2]; + int num_head; + int dim_head; + // get num_head and dim_head in two different ways + if (!transpose_qkv_wb) { + num_head = qkv_w_dims[1]; + dim_head = qkv_w_dims[2]; + } else { + num_head = num_heads; + dim_head = dim_embed / num_head; + } int bsz_seq = batch_size * max_seq_len; int hidden_size = num_head * dim_head; @@ -222,9 +239,10 @@ class FusedAttentionOpKernel : public framework::OpKernel { compute_bias = false; } // (transA, transB, compute_bias) = (false, true, true) + bool transB = transpose_qkv_wb ? false : true; auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, - true, + transB, bsz_seq, output_size, input_size, @@ -288,6 +306,13 @@ class FusedAttentionOpKernel : public framework::OpKernel { qkv_compute.ComputeForward( qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out); } + + if (transpose_qkv_wb) { + // resize the output for fmha compute + qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head}); + qkv_bias_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head}); + } + if (qkv_bias == nullptr) { fmha_ref_compute.ComputeForward(*qkv_out, cache_kv, @@ -316,6 +341,12 @@ class FusedAttentionOpKernel : public framework::OpKernel { fmha_out); } + if (transpose_qkv_wb) { + // resize the output back to make the shape compatible with infer shape + qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size}); + qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size}); + } + // fmha_out: [batch_size, seq_len, num_head, head_dim] // weight: [embed_dim, embed_dim] // out_linear_out: [batch_size, seq_len, embed_dim] @@ -374,6 +405,8 @@ class FusedAttentionGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { using U = LayerNormParamType; + const int num_heads = ctx.Attr("num_heads"); + const bool transpose_qkv_wb = ctx.Attr("transpose_qkv_wb"); const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); const float epsilon = ctx.Attr("epsilon"); const float ln2epsilon = ctx.Attr("ln_epsilon"); @@ -544,8 +577,15 @@ class FusedAttentionGradKernel : public framework::OpKernel { int batch_size = input_x_dims[0]; int max_seq_len = input_x_dims[1]; int dim_embed = input_x_dims[2]; - int num_head = qkv_w_dims[1]; - int dim_head = qkv_w_dims[2]; + int num_head; + int dim_head; + if (!transpose_qkv_wb) { + num_head = qkv_w_dims[1]; + dim_head = qkv_w_dims[2]; + } else { + num_head = num_heads; + dim_head = dim_embed / num_head; + } int bsz_seq = batch_size * max_seq_len; int hidden_size = num_head * dim_head; @@ -562,7 +602,7 @@ class FusedAttentionGradKernel : public framework::OpKernel { } bool transA = false; - bool transB = true; + bool transB = transpose_qkv_wb ? false : true; bool compute_qkv_bias = qkv_bias ? true : false; auto layer_norm_compute = AttnLayerNorm( ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed); @@ -655,6 +695,15 @@ class FusedAttentionGradKernel : public framework::OpKernel { d_out_linear_weight, nullptr); + if (transpose_qkv_wb) { + if (compute_qkv_bias) { + d_qkv_bias_out->Resize( + {batch_size, max_seq_len, 3, num_head, dim_head}); + } else { + d_qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head}); + } + } + if (qkv_bias != nullptr) { fmha_ref_compute.ComputeBackward(*transpose_out_2, has_attn_dropout ? src_mask : nullptr, @@ -691,6 +740,14 @@ class FusedAttentionGradKernel : public framework::OpKernel { d_qkv_out); } + if (transpose_qkv_wb) { + if (compute_qkv_bias) { + d_qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size}); + } else { + d_qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size}); + } + } + if (pre_layer_norm) { auto *ln_mean = ctx.Input("LnMean"); auto *ln_var = ctx.Input("LnVariance"); diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index d09d0b4fbe..04b07e9c35 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -103,6 +103,7 @@ class TestFusedAttentionOp(OpTest): self.query_length, self.query_length, ) + self.transpose_qkv_wb = False def generate_input_data(self): self.query = np.random.rand( @@ -265,7 +266,8 @@ class TestFusedAttentionOp(OpTest): qkv_bias = np.concatenate( (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()) ) - qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + if not self.transpose_qkv_wb: + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) out_linear_bias = paddle.to_tensor( self.out_proj.bias, stop_gradient=False @@ -276,15 +278,23 @@ class TestFusedAttentionOp(OpTest): ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) - q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) - k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) - v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + if not self.transpose_qkv_wb: + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + else: + q_proj_weight = q_proj_weight.numpy() + k_proj_weight = k_proj_weight.numpy() + v_proj_weight = v_proj_weight.numpy() + + concatenate_axis = 1 if self.transpose_qkv_wb else 0 qkv_weight = np.concatenate( - (q_proj_weight, k_proj_weight, v_proj_weight) - ) - qkv_weight = qkv_weight.reshape( - (3, self.num_heads, self.head_dim, self.embed_dim) + (q_proj_weight, k_proj_weight, v_proj_weight), axis=concatenate_axis ) + if not self.transpose_qkv_wb: + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim) + ) x = paddle.to_tensor(self.query, stop_gradient=False) cache_kv = None @@ -317,6 +327,8 @@ class TestFusedAttentionOp(OpTest): self.dropout_prob, self.attn_dropout_prob, ln2_epsilon, + num_heads=self.num_heads, + transpose_qkv_wb=self.transpose_qkv_wb, ) if self.has_cache_kv: @@ -344,6 +356,19 @@ class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp): self.bias_attr = False +class TestFusedAttentionAPITransposeWAndB(TestFusedAttentionOp): + def config(self): + super().config() + self.transpose_qkv_wb = True + + +class TestFusedAttentionAPITransposeWAndBWithoutBias(TestFusedAttentionOp): + def config(self): + super().config() + self.transpose_qkv_wb = True + self.bias_attr = False + + class TestFusedAttentionOpPreLn(TestFusedAttentionOp): def config(self): super().config() diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py index 0917e8b96a..decba63d49 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py @@ -81,6 +81,8 @@ def compute_reference( qkv_bias, out_linear_weight, out_linear_bias, + num_head, + transpose_qkv_wb, ): batch_size = query.shape[0] seq_len = query.shape[1] @@ -93,19 +95,26 @@ def compute_reference( if pre_layer_norm: ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias) - num_head = qkv_weight.shape[1] - head_dim = qkv_weight.shape[2] - # embed_dim, 3, num_heads, self.head_dim - qkv_weight = qkv_weight.transpose((3, 0, 1, 2)) - qkv_weight = qkv_weight.reshape( - qkv_weight.shape[0], - qkv_weight.shape[1] * qkv_weight.shape[2] * qkv_weight.shape[3], - ) - - if qkv_bias is not None: - qkv_bias = qkv_bias.reshape( - qkv_bias.shape[0] * qkv_bias.shape[1] * qkv_bias.shape[2] + head_dim = embed_dim // num_head + if not transpose_qkv_wb: + # embed_dim, 3, num_heads, self.head_dim + qkv_weight = qkv_weight.transpose((3, 0, 1, 2)) + qkv_weight = qkv_weight.reshape( + qkv_weight.shape[0], + qkv_weight.shape[1] * qkv_weight.shape[2] * qkv_weight.shape[3], ) + + if qkv_bias is not None: + qkv_bias = qkv_bias.reshape( + qkv_bias.shape[0] * qkv_bias.shape[1] * qkv_bias.shape[2] + ) + else: + assert len(qkv_weight.shape) == 2 + assert qkv_weight.shape[0] * 3 == qkv_weight.shape[1] + if qkv_bias is not None: + assert len(qkv_bias.shape) == 1 + assert qkv_bias.shape[0] == qkv_weight.shape[1] + if pre_layer_norm: ln_out = ln_out.reshape(batch_size * seq_len, embed_dim) qkv = fc(ln_out, qkv_weight) @@ -189,6 +198,7 @@ class TestFusedAttentionAPI(unittest.TestCase): self.setPreLn() self.setAttnMask() self.setBiasAttr() + self.setTransposeWAndB() self.config() self.generate_input_data() @@ -209,6 +219,9 @@ class TestFusedAttentionAPI(unittest.TestCase): def setBiasAttr(self): self.bias_attr = None + def setTransposeWAndB(self): + self.transpose_qkv_wb = False + def setPreLn(self): self.pre_layer_norm = False @@ -284,6 +297,7 @@ class TestFusedAttentionAPI(unittest.TestCase): self.bias_attr, self.weight_attr, self.bias_attr, + transpose_qkv_wb=self.transpose_qkv_wb, ) if self.bias_attr is not False: qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype( @@ -323,6 +337,8 @@ class TestFusedAttentionAPI(unittest.TestCase): fused_attn_qkv_bias, fused_attn.linear_weight.numpy(), fused_attn_linear_bias, + num_head=self.num_heads, + transpose_qkv_wb=self.transpose_qkv_wb, ) np.testing.assert_allclose( ref_out, out.numpy(), rtol=self.rtol, atol=self.atol @@ -346,6 +362,7 @@ class TestFusedAttentionAPI(unittest.TestCase): self.bias_attr, self.weight_attr, self.bias_attr, + transpose_qkv_wb=self.transpose_qkv_wb, ) x = paddle.static.data( @@ -562,6 +579,8 @@ class TestFusedAttentionAPI(unittest.TestCase): qkv_bias, linear_weight, linear_bias, + num_head=self.num_heads, + transpose_qkv_wb=self.transpose_qkv_wb, ) np.testing.assert_allclose(ref_out, out, rtol=self.rtol, atol=self.atol) @@ -583,5 +602,18 @@ class TestFusedAttentionAPIBiasIsNone(TestFusedAttentionAPI): self.bias_attr = False +class TestFusedAttentionAPITransposeWAndB(TestFusedAttentionAPI): + def setTransposeWAndB(self): + self.transpose_qkv_wb = True + + +class TestFusedAttentionAPITransposeWAndBWithoutBias(TestFusedAttentionAPI): + def setTransposeWAndB(self): + self.transpose_qkv_wb = True + + def setBiasAttr(self): + self.bias_attr = False + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 61270f86d3..d8dfdb0242 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -482,6 +482,8 @@ def fused_multi_head_attention( mode='upscale_in_train', ring_id=-1, add_residual=True, + num_heads=-1, + transpose_qkv_wb=False, name=None, ): r""" @@ -567,6 +569,8 @@ def fused_multi_head_attention( - inference: out = input * (1.0 - p) ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp add_residual (bool, optional): Whether add residual at the end. Default is True. + num_heads (int, optional): If enable transpose_qkv_wb, should provide the num_heads. Default is -1, means not transpose qkv wb. + transpose_qkv_wb (bool, optional): Whether transpose the qkv_weight and qkv_bias in the op. Only support GPU for now. Default is false, means not transpose qkv wb. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -617,21 +621,48 @@ def fused_multi_head_attention( # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out - assert ( - len(qkv_weight.shape) == 4 - ), "The dims of the shape of qkv_weight should be 4." - assert ( - qkv_weight.shape[0] == 3 - ), "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]." - assert ( - qkv_weight.shape[3] == x.shape[2] - ), "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim." - if ring_id == -1: - # under mp, the num head will be split, this equation will not hold + if not transpose_qkv_wb: assert ( - qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[3] - ), "embed_dim must be divisible by num_heads." - + len(qkv_weight.shape) == 4 + ), "The dims of the shape of qkv_weight should be 4." + assert ( + qkv_weight.shape[0] == 3 + ), "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]." + assert ( + qkv_weight.shape[3] == x.shape[2] + ), "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim." + if ring_id == -1: + # under mp, the num head will be split, this equation will not hold + assert ( + qkv_weight.shape[1] * qkv_weight.shape[2] + == qkv_weight.shape[3] + ), "embed_dim must be divisible by num_heads." + else: + assert ( + num_heads > 0 + ), "When enable transpose_qkv_wb, the num_heads should be provided and greater than 0." + assert len(qkv_weight.shape) == 2, ( + "When enable transpose_qkv_wb, the dims of the shape of qkv_weight " + "should be 2 when enable transpose_qkv_wb." + ) + if ring_id == -1: + # under mp, the num head will be split, this equation will not hold + assert qkv_weight.shape[1] == 3 * qkv_weight.shape[0], ( + "When enable transpose_qkv_wb, the shape of qkv_weight should be " + "[embed_dim, 3 * embed_dim] when enable transpose_qkv_wb." + ) + assert qkv_weight.shape[0] == x.shape[2], ( + "When enable transpose_qkv_wb, the 1st dim of qkv_weight and 2nd dim of x " + "should be the same, i.e., embed_dim." + ) + if qkv_bias is not None: + assert ( + len(qkv_bias.shape) == 1 + ), "When enable transpose_qkv_wb, the dims of the shape of qkv_bias should be 1." + assert qkv_bias.shape[0] == qkv_weight.shape[1], ( + "When enable transpose_qkv_wb, the 1st dim of qkv_bias and 2nd dim of " + "qkv_weight should be the same, i.e., embed_dim." + ) ( _, _, @@ -665,6 +696,10 @@ def fused_multi_head_attention( linear_bias, ln_scale, ln_bias, + 'num_heads', + num_heads, + 'transpose_qkv_wb', + transpose_qkv_wb, 'pre_layer_norm', pre_layer_norm, 'epsilon', @@ -754,6 +789,8 @@ def fused_multi_head_attention( 'dropout_implementation': mode, 'add_residual': add_residual, 'ring_id': ring_id, + 'num_heads': num_heads, + 'transpose_qkv_wb': transpose_qkv_wb, } # set outputs diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 78fc72794e..b49cc50086 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -246,6 +246,11 @@ class FusedMultiHeadAttention(Layer): epsilon (float, optional): The small value added to the variance to prevent division by zero. Default: 1e-05. nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using tensor parallel. + transpose_qkv_wb (bool, optional): Support input qkv matmul weight shape as + [hidden_size, 3 * hidden_size] and qkv matmul bias shape as [3 * hidden_size]. + Will transpose the weight to [3, num_head, head_dim, hidden_size] and transpose bias to + [3, num_head, hidden_size] in the fused_attention_op. Only support for GPU for now. + The default value is False, which is not do transpose to qkv_w and qkv_b. ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using tensor parallel. Examples: @@ -283,6 +288,7 @@ class FusedMultiHeadAttention(Layer): epsilon=1e-5, nranks=1, ring_id=-1, + transpose_qkv_wb=False, name=None, ): super().__init__() @@ -315,22 +321,31 @@ class FusedMultiHeadAttention(Layer): # tensor model parallel assert num_heads % nranks == 0 - num_heads = num_heads // nranks + self.num_heads = num_heads // nranks + + self.transpose_qkv_wb = transpose_qkv_wb + if self.transpose_qkv_wb: + # For tensor model parallel, use num_head * head_dim to compute the real shape. + qkv_wight_shape = [embed_dim, 3 * self.num_heads * self.head_dim] + qkv_bias_shape = [3 * self.num_heads * self.head_dim] + else: + qkv_wight_shape = [3, self.num_heads, self.head_dim, embed_dim] + qkv_bias_shape = [3, self.num_heads, self.head_dim] self.qkv_weight = self.create_parameter( - shape=[3, num_heads, self.head_dim, embed_dim], + shape=qkv_wight_shape, attr=qkv_weight_attr, dtype=self._dtype, is_bias=False, ) self.qkv_bias = self.create_parameter( - shape=[3, num_heads, self.head_dim], + shape=qkv_bias_shape, attr=qkv_bias_attr, dtype=self._dtype, is_bias=True, ) self.linear_weight = self.create_parameter( - shape=[num_heads * self.head_dim, embed_dim], + shape=[self.num_heads * self.head_dim, embed_dim], attr=linear_weight_attr, dtype=self._dtype, is_bias=False, @@ -436,6 +451,8 @@ class FusedMultiHeadAttention(Layer): ln_epsilon=self._epsilon, training=self.training, ring_id=self._ring_id, + num_heads=self.num_heads, + transpose_qkv_wb=self.transpose_qkv_wb, name=self.name, ) return out -- GitLab