Unverified commit ec857b85, authored by Yuang Liu, committed by GitHub

Add transpose_qkv_wb flags to the fused_attention_op. (#49494)

Parent 11f5848b
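For orientation before the diff: with transpose_qkv_wb=False (the existing behavior) the fused op expects the QKV weight in the layout [3, num_head, dim_head, dim_embed], while transpose_qkv_wb=True lets it accept the conventional single-GEMM layout [dim_embed, 3 * dim_embed] and handle the difference internally. A minimal NumPy sketch of how the two layouts relate, assuming the fused output axis is ordered (3, num_head, dim_head) as the tests below construct it; the names are illustrative, not taken from the patch:

```python
import numpy as np

dim_embed, num_head = 8, 2
dim_head = dim_embed // num_head

# Layout accepted when transpose_qkv_wb=True: one GEMM weight [dim_embed, 3 * dim_embed].
w_2d = np.random.rand(dim_embed, 3 * dim_embed).astype("float32")

# Equivalent layout expected when transpose_qkv_wb=False: [3, num_head, dim_head, dim_embed].
w_4d = w_2d.reshape(dim_embed, 3, num_head, dim_head).transpose(1, 2, 3, 0)
assert w_4d.shape == (3, num_head, dim_head, dim_embed)
```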
......@@ -108,10 +108,77 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp");
int num_heads = ctx->Attrs().Get<int>("num_heads");
bool transpose_qkv_wb = ctx->Attrs().Get<bool>("transpose_qkv_wb");
// x: qkv's input [batch_size, seq_len, dim_embed]
// if transpose_qkv_wb is False
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
// if transpose_qkv_wb is True
// y: qkv's weight: [dim_embed, 3 * dim_embed]
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("QKVW");
int dim_head;
int hidden_size;
if (transpose_qkv_wb) {
PADDLE_ENFORCE_EQ(y_dim.size(),
2,
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 2 if enable"
"transpose_qkv_wb: (dim_embed, 3 * dim_embed),"
"but received dimensions of"
"Input is [%d]",
y_dim.size()));
PADDLE_ENFORCE_GT(num_heads,
0,
platform::errors::InvalidArgument(
"The num_heads must be provided and greater than 0 "
"if enable transpose_qkv_wb, but we got %d.",
num_heads));
PADDLE_ENFORCE_EQ(y_dim[0] % num_heads,
0,
platform::errors::InvalidArgument(
"First dim of qkv_w must be divisible by num heads "
"if enable transpose_qkv_wb, but receive first "
"dim of qkv_w is %d and num_heads is %d.",
y_dim[0],
num_heads));
if (ctx->Attrs().Get<int>("ring_id") == -1) {
PADDLE_ENFORCE_EQ(y_dim[0] * 3,
y_dim[1],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 2"
"(dim_embed, 3 * dim_embed)."));
}
dim_head = y_dim[0] / num_heads;
hidden_size = y_dim[0];
} else {
PADDLE_ENFORCE_EQ(y_dim.size(),
4,
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4 if not"
"enable transpose_qkv_wb: (3, num_head, dim_head, "
"dim_embed), but received [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(y_dim[0],
3,
platform::errors::InvalidArgument(
"First dim of qkv_w must be 3 if disable "
"transpose_qkv_wb, but we got %d.",
y_dim[0]));
if (ctx->Attrs().Get<int>("ring_id") == -1) {
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
}
num_heads = y_dim[1];
dim_head = y_dim[2];
hidden_size = y_dim[3];
}
PADDLE_ENFORCE_EQ(
x_dim.size(),
3,
......@@ -120,34 +187,18 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"but received dimensions of"
"Input is [%d]",
x_dim.size()));
PADDLE_ENFORCE_EQ(y_dim.size(),
4,
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"but received dimensions of"
"Input is [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(x_dim[2],
y_dim[3],
hidden_size,
platform::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3]"
"ShapeError: the dimension of x_dim[2] and y_dim[3] "
"(y_dim[1] if enable transpose_qkv_w) "
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
x_dim,
y_dim));
if (ctx->Attrs().Get<int>("ring_id") == -1) {
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
}
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
......@@ -157,17 +208,27 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
}
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
if (ctx->HasInput("QKVBias")) {
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
if (transpose_qkv_wb) {
// [batch_size, seq_len, 3 * hidden_size]
ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], 3 * hidden_size});
if (ctx->HasInput("QKVBias")) {
ctx->SetOutputDim("QKVBiasOut", {x_dim[0], x_dim[1], 3 * hidden_size});
}
} else {
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], 3, num_heads, dim_head});
if (ctx->HasInput("QKVBias")) {
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], 3, num_heads, dim_head});
}
}
// [3, batch_size, num_head, seq_len, head_size]
ctx->SetOutputDim("TransposeOut2",
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
{3, x_dim[0], num_heads, x_dim[1], dim_head});
// cache_seq_len + seq_len if cache else seq_len
auto out_seq_len = x_dim[1];
......@@ -193,11 +254,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
x_dim[0],
c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2],
y_dim[1],
num_heads,
paddle::platform::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
y_dim[1],
num_heads,
c_dim[2])); // num_head
// In compile stage, input seq_len can be -1, in that case
// c_dim[3] may < 0 in while
......@@ -209,12 +270,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"The forth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
}
PADDLE_ENFORCE_EQ(c_dim[4],
y_dim[2],
dim_head,
paddle::platform::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
y_dim[2],
dim_head,
c_dim[4])); // head_size
out_seq_len += c_dim[3];
......@@ -224,25 +286,26 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
}
// [batch, num_head, seq_len, out_seq_len]
ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
ctx->SetOutputDim("QKOut", {x_dim[0], num_heads, x_dim[1], out_seq_len});
if (ctx->HasInput("SrcMask")) {
ctx->SetOutputDim("SrcMaskOut",
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
{x_dim[0], num_heads, x_dim[1], out_seq_len});
}
// the same as QKOut's shape.
ctx->SetOutputDim("AttnDropoutOut",
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
{x_dim[0], num_heads, x_dim[1], out_seq_len});
if (ctx->Attrs().Get<bool>("is_test") == false) {
ctx->SetOutputDim("AttnDropoutMaskOut",
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
{x_dim[0], num_heads, x_dim[1], out_seq_len});
}
ctx->SetOutputDim("SoftmaxOut",
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
{x_dim[0], num_heads, x_dim[1], out_seq_len});
// [batch_size, num_heads, seq_len, head_dim]
ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
ctx->SetOutputDim("QKTVOut", {x_dim[0], num_heads, x_dim[1], dim_head});
// [batch_size, seq_len, number of heads*head size]
ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], num_heads, dim_head});
ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));
if (ctx->Attrs().Get<bool>("is_test") == false) {
......@@ -315,6 +378,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("CacheKVOut", "The udpated cache KV.");
AddOutput("Y", "Result after attention.");
AddAttr<int>("num_heads", "The number head for multi_head_attention.")
.SetDefault(-1);
AddAttr<bool>("transpose_qkv_wb",
"The qkv_w shape is (h, 3h), do transpose to it.")
.SetDefault(false);
AddAttr<bool>("pre_layer_norm",
"if true, the attention op uses pre_layer_norm architecure, "
"else, uses post_layer_norm architecuture. "
......
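To summarize the InferShape changes above, the QKVOut shape now depends on the flag. A small Python sketch of just the shape selection (a mirror of the logic, not the C++ code itself):

```python
def qkv_out_shape(x_dim, y_dim, num_heads, transpose_qkv_wb):
    """Mirror of the QKVOut shape selection in the diff above (sketch only)."""
    batch_size, seq_len, _ = x_dim            # x: [batch_size, seq_len, dim_embed]
    if transpose_qkv_wb:
        hidden_size = y_dim[0]                # y: [dim_embed, 3 * dim_embed]
        return [batch_size, seq_len, 3 * hidden_size]
    num_head, dim_head = y_dim[1], y_dim[2]   # y: [3, num_head, dim_head, dim_embed]
    return [batch_size, seq_len, 3, num_head, dim_head]
```

For example, qkv_out_shape([2, 16, 8], [8, 24], num_heads=2, transpose_qkv_wb=True) gives [2, 16, 24], while the 4-D weight [3, 2, 4, 8] gives [2, 16, 3, 2, 4].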
......@@ -25,9 +25,12 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
......@@ -87,8 +90,14 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *ln_var = ctx.Output<phi::DenseTensor>("LnVariance");
auto *ln_out = ctx.Output<phi::DenseTensor>("LnOut");
const auto num_heads = ctx.Attr<int>("num_heads");
const auto transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");
// x: qkv's input [batch_size, seq_len, dim_embed]
// if transpose_qkv_wb is False
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
// if transpose_qkv_wb is True
// y: qkv's weight: [dim_embed, 3 * dim_embed]
auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
auto *qkv_out = ctx.Output<phi::DenseTensor>("QKVOut");
......@@ -206,8 +215,16 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int num_head = qkv_w_dims[1];
int dim_head = qkv_w_dims[2];
int num_head;
int dim_head;
// Determine num_head and dim_head according to the qkv weight layout.
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
num_head = num_heads;
dim_head = dim_embed / num_head;
}
int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
......@@ -222,9 +239,10 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
compute_bias = false;
}
// (transA, transB, compute_bias) = (false, !transpose_qkv_wb, true)
bool transB = transpose_qkv_wb ? false : true;
auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
false,
true,
transB,
bsz_seq,
output_size,
input_size,
......@@ -288,6 +306,13 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_compute.ComputeForward(
qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out);
}
if (transpose_qkv_wb) {
// Resize the output to the 5-D layout expected by the FMHA computation.
qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
qkv_bias_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
}
if (qkv_bias == nullptr) {
fmha_ref_compute.ComputeForward(*qkv_out,
cache_kv,
......@@ -316,6 +341,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
fmha_out);
}
if (transpose_qkv_wb) {
// Resize the output back so the shape matches what InferShape reports.
qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
}
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
......@@ -374,6 +405,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const int num_heads = ctx.Attr<int>("num_heads");
const bool transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
const float ln2epsilon = ctx.Attr<float>("ln_epsilon");
......@@ -544,8 +577,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
int batch_size = input_x_dims[0];
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int num_head = qkv_w_dims[1];
int dim_head = qkv_w_dims[2];
int num_head;
int dim_head;
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
num_head = num_heads;
dim_head = dim_embed / num_head;
}
int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
......@@ -562,7 +602,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
}
bool transA = false;
bool transB = true;
bool transB = transpose_qkv_wb ? false : true;
bool compute_qkv_bias = qkv_bias ? true : false;
auto layer_norm_compute = AttnLayerNorm<T>(
ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
......@@ -655,6 +695,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_weight,
nullptr);
if (transpose_qkv_wb) {
if (compute_qkv_bias) {
d_qkv_bias_out->Resize(
{batch_size, max_seq_len, 3, num_head, dim_head});
} else {
d_qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
}
}
if (qkv_bias != nullptr) {
fmha_ref_compute.ComputeBackward(*transpose_out_2,
has_attn_dropout ? src_mask : nullptr,
......@@ -691,6 +740,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_qkv_out);
}
if (transpose_qkv_wb) {
if (compute_qkv_bias) {
d_qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
} else {
d_qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
}
}
if (pre_layer_norm) {
auto *ln_mean = ctx.Input<phi::DenseTensor>("LnMean");
auto *ln_var = ctx.Input<phi::DenseTensor>("LnVariance");
......
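The kernel changes above rely on Resize being a metadata-only view change around the FMHA step: the 3-D [batch, seq, 3 * hidden] buffer produced by the single GEMM is viewed in the 5-D layout the FMHA code expects, then viewed back. A NumPy analogue of that round trip (illustrative only, not the kernel code):

```python
import numpy as np

batch, seq, num_head, dim_head = 2, 4, 3, 5
hidden = num_head * dim_head

qkv_out = np.random.rand(batch, seq, 3 * hidden).astype("float32")

# View the GEMM output in the 5-D layout used by the FMHA step; on a
# contiguous buffer this is a metadata change, no data is copied.
fmha_view = qkv_out.reshape(batch, seq, 3, num_head, dim_head)
assert fmha_view.base is qkv_out

# View it back to the 3-D layout reported by InferShape.
back = fmha_view.reshape(batch, seq, 3 * hidden)
assert back.shape == qkv_out.shape
```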
......@@ -103,6 +103,7 @@ class TestFusedAttentionOp(OpTest):
self.query_length,
self.query_length,
)
self.transpose_qkv_wb = False
def generate_input_data(self):
self.query = np.random.rand(
......@@ -265,7 +266,8 @@ class TestFusedAttentionOp(OpTest):
qkv_bias = np.concatenate(
(q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())
)
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
if not self.transpose_qkv_wb:
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
out_linear_bias = paddle.to_tensor(
self.out_proj.bias, stop_gradient=False
......@@ -276,15 +278,23 @@ class TestFusedAttentionOp(OpTest):
ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)
q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
if not self.transpose_qkv_wb:
q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
else:
q_proj_weight = q_proj_weight.numpy()
k_proj_weight = k_proj_weight.numpy()
v_proj_weight = v_proj_weight.numpy()
concatenate_axis = 1 if self.transpose_qkv_wb else 0
qkv_weight = np.concatenate(
(q_proj_weight, k_proj_weight, v_proj_weight)
)
qkv_weight = qkv_weight.reshape(
(3, self.num_heads, self.head_dim, self.embed_dim)
(q_proj_weight, k_proj_weight, v_proj_weight), axis=concatenate_axis
)
if not self.transpose_qkv_wb:
qkv_weight = qkv_weight.reshape(
(3, self.num_heads, self.head_dim, self.embed_dim)
)
x = paddle.to_tensor(self.query, stop_gradient=False)
cache_kv = None
......@@ -317,6 +327,8 @@ class TestFusedAttentionOp(OpTest):
self.dropout_prob,
self.attn_dropout_prob,
ln2_epsilon,
num_heads=self.num_heads,
transpose_qkv_wb=self.transpose_qkv_wb,
)
if self.has_cache_kv:
......@@ -344,6 +356,19 @@ class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):
self.bias_attr = False
class TestFusedAttentionAPITransposeWAndB(TestFusedAttentionOp):
def config(self):
super().config()
self.transpose_qkv_wb = True
class TestFusedAttentionAPITransposeWAndBWithoutBias(TestFusedAttentionOp):
def config(self):
super().config()
self.transpose_qkv_wb = True
self.bias_attr = False
class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
def config(self):
super().config()
......
......@@ -81,6 +81,8 @@ def compute_reference(
qkv_bias,
out_linear_weight,
out_linear_bias,
num_head,
transpose_qkv_wb,
):
batch_size = query.shape[0]
seq_len = query.shape[1]
......@@ -93,19 +95,26 @@ def compute_reference(
if pre_layer_norm:
ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias)
num_head = qkv_weight.shape[1]
head_dim = qkv_weight.shape[2]
# embed_dim, 3, num_heads, self.head_dim
qkv_weight = qkv_weight.transpose((3, 0, 1, 2))
qkv_weight = qkv_weight.reshape(
qkv_weight.shape[0],
qkv_weight.shape[1] * qkv_weight.shape[2] * qkv_weight.shape[3],
)
if qkv_bias is not None:
qkv_bias = qkv_bias.reshape(
qkv_bias.shape[0] * qkv_bias.shape[1] * qkv_bias.shape[2]
head_dim = embed_dim // num_head
if not transpose_qkv_wb:
# embed_dim, 3, num_heads, self.head_dim
qkv_weight = qkv_weight.transpose((3, 0, 1, 2))
qkv_weight = qkv_weight.reshape(
qkv_weight.shape[0],
qkv_weight.shape[1] * qkv_weight.shape[2] * qkv_weight.shape[3],
)
if qkv_bias is not None:
qkv_bias = qkv_bias.reshape(
qkv_bias.shape[0] * qkv_bias.shape[1] * qkv_bias.shape[2]
)
else:
assert len(qkv_weight.shape) == 2
assert qkv_weight.shape[0] * 3 == qkv_weight.shape[1]
if qkv_bias is not None:
assert len(qkv_bias.shape) == 1
assert qkv_bias.shape[0] == qkv_weight.shape[1]
if pre_layer_norm:
ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
qkv = fc(ln_out, qkv_weight)
......@@ -189,6 +198,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.setPreLn()
self.setAttnMask()
self.setBiasAttr()
self.setTransposeWAndB()
self.config()
self.generate_input_data()
......@@ -209,6 +219,9 @@ class TestFusedAttentionAPI(unittest.TestCase):
def setBiasAttr(self):
self.bias_attr = None
def setTransposeWAndB(self):
self.transpose_qkv_wb = False
def setPreLn(self):
self.pre_layer_norm = False
......@@ -284,6 +297,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.bias_attr,
self.weight_attr,
self.bias_attr,
transpose_qkv_wb=self.transpose_qkv_wb,
)
if self.bias_attr is not False:
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype(
......@@ -323,6 +337,8 @@ class TestFusedAttentionAPI(unittest.TestCase):
fused_attn_qkv_bias,
fused_attn.linear_weight.numpy(),
fused_attn_linear_bias,
num_head=self.num_heads,
transpose_qkv_wb=self.transpose_qkv_wb,
)
np.testing.assert_allclose(
ref_out, out.numpy(), rtol=self.rtol, atol=self.atol
......@@ -346,6 +362,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.bias_attr,
self.weight_attr,
self.bias_attr,
transpose_qkv_wb=self.transpose_qkv_wb,
)
x = paddle.static.data(
......@@ -562,6 +579,8 @@ class TestFusedAttentionAPI(unittest.TestCase):
qkv_bias,
linear_weight,
linear_bias,
num_head=self.num_heads,
transpose_qkv_wb=self.transpose_qkv_wb,
)
np.testing.assert_allclose(ref_out, out, rtol=self.rtol, atol=self.atol)
......@@ -583,5 +602,18 @@ class TestFusedAttentionAPIBiasIsNone(TestFusedAttentionAPI):
self.bias_attr = False
class TestFusedAttentionAPITransposeWAndB(TestFusedAttentionAPI):
def setTransposeWAndB(self):
self.transpose_qkv_wb = True
class TestFusedAttentionAPITransposeWAndBWithoutBias(TestFusedAttentionAPI):
def setTransposeWAndB(self):
self.transpose_qkv_wb = True
def setBiasAttr(self):
self.bias_attr = False
if __name__ == "__main__":
unittest.main()
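For the API tests above, the transposed-layout reference path reduces to one GEMM followed by a reshape and transpose. A compact NumPy sketch of that data flow; the variable names are illustrative, and the (3, num_head, head_dim) ordering of the fused axis is an assumption that matches how the tests build the weight:

```python
import numpy as np

batch, seq, num_head, head_dim = 2, 4, 3, 5
embed_dim = num_head * head_dim

x = np.random.rand(batch, seq, embed_dim).astype("float32")
qkv_w = np.random.rand(embed_dim, 3 * embed_dim).astype("float32")  # [embed_dim, 3 * embed_dim]
qkv_b = np.random.rand(3 * embed_dim).astype("float32")             # [3 * embed_dim]

# One GEMM produces q, k and v; the fused axis is then split per head.
qkv = x.reshape(batch * seq, embed_dim) @ qkv_w + qkv_b
q, k, v = qkv.reshape(batch, seq, 3, num_head, head_dim).transpose(2, 0, 3, 1, 4)
assert q.shape == (batch, num_head, seq, head_dim)
```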
......@@ -482,6 +482,8 @@ def fused_multi_head_attention(
mode='upscale_in_train',
ring_id=-1,
add_residual=True,
num_heads=-1,
transpose_qkv_wb=False,
name=None,
):
r"""
......@@ -567,6 +569,8 @@ def fused_multi_head_attention(
- inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp
add_residual (bool, optional): Whether add residual at the end. Default is True.
num_heads (int, optional): The number of attention heads. Must be provided when transpose_qkv_wb is enabled. Default is -1, which means transpose_qkv_wb is not used.
transpose_qkv_wb (bool, optional): Whether to transpose the qkv_weight and qkv_bias inside the op. Only supported on GPU for now. Default is False, which means qkv_w and qkv_b are not transposed.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -617,21 +621,48 @@ def fused_multi_head_attention(
# pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
# qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
# linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
assert (
len(qkv_weight.shape) == 4
), "The dims of the shape of qkv_weight should be 4."
assert (
qkv_weight.shape[0] == 3
), "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
assert (
qkv_weight.shape[3] == x.shape[2]
), "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
if ring_id == -1:
# under mp, the num head will be split, this equation will not hold
if not transpose_qkv_wb:
assert (
qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[3]
), "embed_dim must be divisible by num_heads."
len(qkv_weight.shape) == 4
), "The dims of the shape of qkv_weight should be 4."
assert (
qkv_weight.shape[0] == 3
), "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
assert (
qkv_weight.shape[3] == x.shape[2]
), "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
if ring_id == -1:
# under mp, the num head will be split, this equation will not hold
assert (
qkv_weight.shape[1] * qkv_weight.shape[2]
== qkv_weight.shape[3]
), "embed_dim must be divisible by num_heads."
else:
assert (
num_heads > 0
), "When enable transpose_qkv_wb, the num_heads should be provided and greater than 0."
assert len(qkv_weight.shape) == 2, (
"When enable transpose_qkv_wb, the dims of the shape of qkv_weight "
"should be 2 when enable transpose_qkv_wb."
)
if ring_id == -1:
# under mp, the num head will be split, this equation will not hold
assert qkv_weight.shape[1] == 3 * qkv_weight.shape[0], (
"When enable transpose_qkv_wb, the shape of qkv_weight should be "
"[embed_dim, 3 * embed_dim] when enable transpose_qkv_wb."
)
assert qkv_weight.shape[0] == x.shape[2], (
"When enable transpose_qkv_wb, the 1st dim of qkv_weight and 2nd dim of x "
"should be the same, i.e., embed_dim."
)
if qkv_bias is not None:
assert (
len(qkv_bias.shape) == 1
), "When enable transpose_qkv_wb, the dims of the shape of qkv_bias should be 1."
assert qkv_bias.shape[0] == qkv_weight.shape[1], (
"When enable transpose_qkv_wb, the 1st dim of qkv_bias and 2nd dim of "
"qkv_weight should be the same, i.e., embed_dim."
)
(
_,
_,
......@@ -665,6 +696,10 @@ def fused_multi_head_attention(
linear_bias,
ln_scale,
ln_bias,
'num_heads',
num_heads,
'transpose_qkv_wb',
transpose_qkv_wb,
'pre_layer_norm',
pre_layer_norm,
'epsilon',
......@@ -754,6 +789,8 @@ def fused_multi_head_attention(
'dropout_implementation': mode,
'add_residual': add_residual,
'ring_id': ring_id,
'num_heads': num_heads,
'transpose_qkv_wb': transpose_qkv_wb,
}
# set outputs
......
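The new assertion branch in fused_multi_head_attention encodes a shape contract for the transposed layout; restated here as a standalone checker for readability (a sketch, not part of the API):

```python
def check_transposed_qkv_shapes(x_shape, qkv_w_shape, qkv_b_shape, num_heads, ring_id=-1):
    # num_heads has no usable default when transpose_qkv_wb is enabled.
    assert num_heads > 0
    # qkv_weight must be the single GEMM matrix [embed_dim, 3 * embed_dim].
    assert len(qkv_w_shape) == 2 and qkv_w_shape[0] == x_shape[2]
    if ring_id == -1:
        # Under tensor model parallel the head dimension is split,
        # so this relation only holds in the non-distributed case.
        assert qkv_w_shape[1] == 3 * qkv_w_shape[0]
    if qkv_b_shape is not None:
        # The bias is the flat vector matching the weight's output dimension.
        assert len(qkv_b_shape) == 1 and qkv_b_shape[0] == qkv_w_shape[1]
```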
......@@ -246,6 +246,11 @@ class FusedMultiHeadAttention(Layer):
epsilon (float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05.
nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using tensor parallel.
transpose_qkv_wb (bool, optional): Support an input QKV matmul weight of shape
[hidden_size, 3 * hidden_size] and a QKV matmul bias of shape [3 * hidden_size].
The weight is transposed to [3, num_head, head_dim, hidden_size] and the bias to
[3, num_head, head_dim] inside the fused_attention_op. Only supported on GPU for now.
The default value is False, which means no transpose is applied to qkv_w and qkv_b.
ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using tensor parallel.
Examples:
......@@ -283,6 +288,7 @@ class FusedMultiHeadAttention(Layer):
epsilon=1e-5,
nranks=1,
ring_id=-1,
transpose_qkv_wb=False,
name=None,
):
super().__init__()
......@@ -315,22 +321,31 @@ class FusedMultiHeadAttention(Layer):
# tensor model parallel
assert num_heads % nranks == 0
num_heads = num_heads // nranks
self.num_heads = num_heads // nranks
self.transpose_qkv_wb = transpose_qkv_wb
if self.transpose_qkv_wb:
# For tensor model parallel, use num_head * head_dim to compute the real shape.
qkv_weight_shape = [embed_dim, 3 * self.num_heads * self.head_dim]
qkv_bias_shape = [3 * self.num_heads * self.head_dim]
else:
qkv_weight_shape = [3, self.num_heads, self.head_dim, embed_dim]
qkv_bias_shape = [3, self.num_heads, self.head_dim]
self.qkv_weight = self.create_parameter(
shape=[3, num_heads, self.head_dim, embed_dim],
shape=qkv_weight_shape,
attr=qkv_weight_attr,
dtype=self._dtype,
is_bias=False,
)
self.qkv_bias = self.create_parameter(
shape=[3, num_heads, self.head_dim],
shape=qkv_bias_shape,
attr=qkv_bias_attr,
dtype=self._dtype,
is_bias=True,
)
self.linear_weight = self.create_parameter(
shape=[num_heads * self.head_dim, embed_dim],
shape=[self.num_heads * self.head_dim, embed_dim],
attr=linear_weight_attr,
dtype=self._dtype,
is_bias=False,
......@@ -436,6 +451,8 @@ class FusedMultiHeadAttention(Layer):
ln_epsilon=self._epsilon,
training=self.training,
ring_id=self._ring_id,
num_heads=self.num_heads,
transpose_qkv_wb=self.transpose_qkv_wb,
name=self.name,
)
return out
......
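Putting the layer-level change together, a hedged usage sketch: it assumes FusedMultiHeadAttention is importable from paddle.incubate.nn with embed_dim and num_heads as its leading arguments (not shown in this diff) and, per the docstring above, requires a GPU:

```python
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

embed_dim, num_heads = 128, 8
attn = FusedMultiHeadAttention(embed_dim, num_heads, transpose_qkv_wb=True)

# With transpose_qkv_wb=True the QKV weight stays a single GEMM matrix
# instead of the [3, num_heads, head_dim, embed_dim] layout.
assert list(attn.qkv_weight.shape) == [embed_dim, 3 * embed_dim]

x = paddle.rand([2, 16, embed_dim])  # [batch_size, seq_len, embed_dim]
out = attn(x)                        # [batch_size, seq_len, embed_dim]
```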