Unverified commit e44ff495, authored by Yuang Liu, committed by GitHub

Fused attention pass mp support (#50320)

Parent a7539508
......@@ -33,6 +33,7 @@ namespace patterns {
// 2. Add attn mask for qk product before the softmax or not.
// 3. Do attn dropout or not.
// 4. Add residual to the out linear result or not.
// 5. Use model tensor parallel or not.
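// When tensor model parallel is used, the pattern additionally matches a
// c_identity op before the fused qkv projection and an mp allreduce sum
// after the out linear bias add (see the node declarations below).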
struct FusedAttentionPattern : public PatternBase {
FusedAttentionPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
......@@ -41,7 +42,8 @@ struct FusedAttentionPattern : public PatternBase {
bool pre_layer_norm, // do pre ln or not
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
bool add_residual, // add residual to out linear or not
bool use_mp); // use tensor parallel or not
// pre layer norm
PATTERN_DECL_NODE(pre_layer_norm_op);
......@@ -51,6 +53,10 @@ struct FusedAttentionPattern : public PatternBase {
PATTERN_DECL_NODE(pre_layer_norm_mean);
PATTERN_DECL_NODE(pre_layer_norm_variance);
// c_identity for mp
PATTERN_DECL_NODE(c_identity_op);
PATTERN_DECL_NODE(c_identity_out);
// fuse qkv projection
PATTERN_DECL_NODE(fuse_qkv_matmul_op);
PATTERN_DECL_NODE(fuse_qkv_matmul_w);
......@@ -111,6 +117,10 @@ struct FusedAttentionPattern : public PatternBase {
PATTERN_DECL_NODE(out_linear_ele_add_bias);
PATTERN_DECL_NODE(out_linear_ele_add_out);
// allreduce for mp
PATTERN_DECL_NODE(mp_allreudce_sum_op);
PATTERN_DECL_NODE(mp_allreudce_sum_out);
PATTERN_DECL_NODE(out_linear_dropout_op);
PATTERN_DECL_NODE(out_linear_dropout_out);
PATTERN_DECL_NODE(out_linear_dropout_mask);
......@@ -131,13 +141,14 @@ struct FusedAttentionPattern : public PatternBase {
// Declare the grad pattern for multi head attention
struct FusedAttentionGradPattern : public PatternBase {
FusedAttentionGradPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
: PatternBase(pattern, name_scope, "fused_attention_grad_pattern") {}
PDNode* operator()(PDNode* x,
bool pre_layer_norm, // pre ln
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
bool add_residual, // add residual to out linear or not
bool use_mp); // use tensor parallel or not
// post layer norm grad
PATTERN_DECL_NODE(post_layer_norm_grad_op);
......@@ -162,6 +173,10 @@ struct FusedAttentionGradPattern : public PatternBase {
PATTERN_DECL_NODE(out_linear_dropout_grad_mask);
PATTERN_DECL_NODE(out_linear_dropout_grad_out);
// grad of the mp allreduce (a c_identity op)
PATTERN_DECL_NODE(mp_allreudce_sum_grad_op); // c_identity
PATTERN_DECL_NODE(mp_allreudce_sum_grad_out);
PATTERN_DECL_NODE(out_linear_ele_add_grad_op);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias);
......@@ -235,6 +250,10 @@ struct FusedAttentionGradPattern : public PatternBase {
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad);
// grad of the mp c_identity (an mp allreduce sum op)
PATTERN_DECL_NODE(c_identity_grad_op); // mp_allreduce_sum
PATTERN_DECL_NODE(c_identity_grad_out);
// pre layer norm grad
PATTERN_DECL_NODE(pre_layer_norm_grad_op);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale);
......@@ -296,6 +315,7 @@ class FusedAttentionsPass : public FusePassBase {
// 4. Add residual? [Res]
// 5. Do post layer norm? [Post]
// 6. Forward or Backward? [Fwd/Bwd]
// 7. Use tensor model parallel? [MP]
// If true, the function name will have an abbreviation part.
// If false, the function name won't contain an abbreviation for it.
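// For example, PreMaskDropResMPFwd below handles the forward pattern with
// pre layer norm, attn mask, attn dropout, residual, and tensor model parallel.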
......@@ -305,6 +325,28 @@ class FusedAttentionsPass : public FusePassBase {
ir::Graph* PreMaskDropResBwd(Graph* graph,
FusedAttentionPassCache* cache) const;
ir::Graph* PreMaskDropResMPFwd(Graph* graph,
FusedAttentionPassCache* cache) const;
ir::Graph* PreMaskDropResMPBwd(Graph* graph,
FusedAttentionPassCache* cache) const;
ir::Graph* ForwardHandlerHelper(Graph* graph,
FusedAttentionPassCache* cache,
bool pre_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual,
bool use_mp) const;
ir::Graph* BackwardHandlerHelper(Graph* graph,
FusedAttentionPassCache* cache,
bool pre_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual,
bool use_mp) const;
const std::string GenerateCacheKey(const std::string anchor,
const std::string var_name,
int block_id) const {
......
......@@ -120,6 +120,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
auto y_dim = ctx->GetInputDim("QKVW");
int dim_head;
int hidden_size;
int nranks = 1;
if (transpose_qkv_wb) {
PADDLE_ENFORCE_EQ(y_dim.size(),
2,
......@@ -149,8 +150,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 2"
"(dim_embed, 3 * dim_embed)."));
} else {
// compute the mp nranks
nranks = (y_dim[0] * 3) / y_dim[1];
}
dim_head = y_dim[0] / num_heads;
dim_head = y_dim[0] / (num_heads * nranks);
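// illustrative example (matching the mp unit test): dim_embed = 768 with 6 heads
// per rank and head dim 64 gives a column-parallel QKVW of [768, 3 * 6 * 64] = [768, 1152],
// so nranks = 768 * 3 / 1152 = 2 and dim_head = 768 / (6 * 2) = 64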
hidden_size = y_dim[0];
} else {
PADDLE_ENFORCE_EQ(y_dim.size(),
......@@ -210,11 +214,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
}
if (transpose_qkv_wb) {
// [batch_size, seq_len, 3 * hidden_size]
ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], 3 * hidden_size});
// [batch_size, seq_len, 3 * num_heads * dim_head]
ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], 3 * num_heads * dim_head});
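// under tensor parallel, num_heads and dim_head are the per-rank values,
// so this per-rank QKVOut is 1/nranks of the full 3 * hidden_size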
if (ctx->HasInput("QKVBias")) {
ctx->SetOutputDim("QKVBiasOut", {x_dim[0], x_dim[1], 3 * hidden_size});
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], 3 * num_heads * dim_head});
}
} else {
// [batch_size, seq_len, 3, num_head, head_size]
......
......@@ -217,13 +217,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
int num_head;
int dim_head;
int nranks = 1;
// get num_head and dim_head in two different ways
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
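// with transpose_qkv_wb, QKVW is column-parallel: [dim_embed, 3 * dim_embed / nranks],
// so the tensor parallel degree can be recovered from the column count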
nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
num_head = num_heads;
dim_head = dim_embed / num_head;
dim_head = dim_embed / (num_head * nranks);
}
int bsz_seq = batch_size * max_seq_len;
......@@ -579,12 +581,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
int dim_embed = input_x_dims[2];
int num_head;
int dim_head;
int nranks = 1;
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
num_head = num_heads;
dim_head = dim_embed / num_head;
dim_head = dim_embed / (num_head * nranks);
}
int bsz_seq = batch_size * max_seq_len;
......
......@@ -908,3 +908,15 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_dygraph_save_for_auto_infer
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if(WITH_GPU)
bash_test_modules(
test_fused_attention_pass_with_mp
START_BASH
test_fused_attention_pass_with_mp.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21400;http_proxy=;https_proxy=")
set_tests_properties(test_fused_attention_pass_with_mp PROPERTIES TIMEOUT
"120")
endif()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.fluid as fluid
import paddle.nn.functional as F
from paddle.distributed.passes import PassManager, new_pass
paddle.enable_static()
class MultiHeadAttentionWithMP(paddle.nn.Layer):
def __init__(
self,
embed_dim,
num_heads,
add_residual=True,
pre_ln=True,
attn_dropout=True,
):
super(MultiHeadAttentionWithMP, self).__init__()
self.embed_dim = embed_dim
self.kdim = embed_dim
self.vdim = embed_dim
self.num_heads = num_heads
self.add_residual = add_residual
self.pre_ln = pre_ln
self.attn_dropout = attn_dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert num_heads % 2 == 0
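# the test runs with a tensor parallel degree of 2, so each rank keeps half of the heads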
self.num_heads = num_heads // 2
self.norm1 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)
self.norm2 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)
self.qkv_proj = paddle.nn.Linear(
embed_dim, 3 * self.num_heads * self.head_dim
)
self.out_proj = paddle.nn.Linear(
self.num_heads * self.head_dim, embed_dim
)
self.dropout = paddle.nn.Dropout(1e-10, mode="upscale_in_train")
def forward(self, x, attn_mask=None):
residual = x
if self.pre_ln:
# pre layer norm
x = self.norm1(x)
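# c_identity is an identity in the forward pass; its grad is an allreduce over the mp group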
x = paddle.distributed.collective._c_identity(x)
# compute qkv
qkv = self.qkv_proj(x)
qkv = paddle.reshape(qkv, [0, 0, 3 * self.num_heads, self.head_dim])
qkv = paddle.transpose(qkv, [0, 2, 1, 3])
q, k, v = paddle.split(qkv, num_or_sections=3, axis=1)
# compute core attention
q = paddle.scale(q, scale=self.head_dim**-0.5)
product = paddle.matmul(x=q, y=k, transpose_y=True)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.attn_dropout:
weights = F.dropout(
weights, 0.1, training=self.training, mode="upscale_in_train"
)
out = paddle.matmul(weights, v)
out = paddle.transpose(out, perm=[0, 2, 1, 3])
out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
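# each rank produces only a partial out_proj result, so sum the partials across
# the mp group; the grad of this allreduce is a c_identity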
out = paddle.distributed.collective._mp_allreduce(
out, use_calc_stream=True, use_model_parallel=True
)
out = self.dropout(out)
if self.add_residual:
out = residual + out
if not self.pre_ln:
# post layer norm
out = self.norm2(out)
return out
class TestFusedAttentionPassWithMP(unittest.TestCase):
def setUp(self):
fleet.init()
self.endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')
self.current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
self.nranks = len(self.endpoints)
self.rank = self.endpoints.index(self.current_endpoint)
self.gpu_id = int(os.getenv("FLAGS_selected_gpus"))
self.place = fluid.CUDAPlace(self.gpu_id)
self.exe = fluid.Executor(self.place)
self.endpoints.remove(self.current_endpoint)
self.other_endpoints = self.endpoints
self.add_residual = True
self.pre_ln = True
self.attn_dropout = True
self.add_mask = True
self.x_data = None
self.mask_data = None
def get_rst(self, use_pass=False):
batch_size = 2
seq_len = 1024
hidden_size = 768
num_heads = 12
np.random.seed(1234)
if self.x_data is None:
self.x_data = np.random.rand(batch_size, seq_len, seq_len).astype(
'float32'
)
self.mask_data = np.random.rand(
batch_size, num_heads // 2, seq_len, seq_len
).astype('float32')
main_prog = paddle.static.Program()
main_prog.random_seed = 1234
startup_prog = paddle.static.Program()
startup_prog.random_seed = 1234
with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
name="x",
shape=[-1, seq_len, seq_len],
dtype='float32',
)
if self.add_mask:
attn_mask = paddle.static.data(
name="attn_mask",
shape=[-1, num_heads // 2, seq_len, seq_len],
dtype='float32',
)
else:
attn_mask = None
data_linear = paddle.nn.Linear(seq_len, hidden_size)
multi_head_attn = MultiHeadAttentionWithMP(
hidden_size,
num_heads,
add_residual=self.add_residual,
pre_ln=self.pre_ln,
attn_dropout=self.attn_dropout,
)
attn_input = data_linear(data)
out = multi_head_attn(attn_input, attn_mask)
loss = paddle.mean(out)
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(loss)
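# manually build the NCCL communicator on ring_id 0 that _c_identity / _mp_allreduce use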
startup_block = startup_prog.global_block()
nccl_id_var = startup_block.create_var(
name=fluid.unique_name.generate('nccl_id'),
persistable=True,
type=fluid.core.VarDesc.VarType.RAW,
)
startup_block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': self.rank,
'endpoint': self.current_endpoint,
'other_endpoints': self.other_endpoints,
},
)
startup_block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': self.nranks,
'rank': self.rank,
'ring_id': 0,
'device_id': self.gpu_id,
},
)
if use_pass:
pass_manager = PassManager([new_pass("fused_attention")])
pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops
assert ops[2].type == 'fused_attention'
assert ops[3].type == 'reduce_mean'
assert ops[5].type == 'reduce_mean_grad'
assert ops[6].type == 'fused_attention_grad'
# ops[0], ops[1]: the input linear; ops[4]: fill_constant for the loss grad;
# ops[7], ops[8]: the input linear backward;
# so ops[9] should be the first optimizer (sgd) op
assert ops[9].type == 'sgd'
self.exe.run(startup_prog)
for i in range(2):
rst = self.exe.run(
main_prog,
feed={'x': self.x_data, 'attn_mask': self.mask_data},
fetch_list=[loss],
)
return rst
def test_pass(self):
fused_rst = self.get_rst(use_pass=True)
non_fused_rst = self.get_rst()
assert np.allclose(fused_rst, non_fused_rst, atol=1e-5)
if __name__ == "__main__":
unittest.main()
#!/bin/bash
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# use default values
# FIXME: randomly fails with "Unknown command lines -c (or -m)".
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch fused_attention_pass_with_mp.py
......@@ -58,6 +58,7 @@ test_fleet_recompute_meta_optimizer,LINUX;WIN32,GPU;XPU;ASCEND;ASCEND_CL,,,test_
test_fleet_private_function,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_new_group,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_new_group.sh,2,,http_proxy=;https_proxy=,
test_c_comm_init_op,LINUX,GPU;XPU;ASCEND;ASCEND_CL,120,DIST,test_c_comm_init_op.sh,2,,http_proxy=;https_proxy=,
test_fused_attention_pass_with_mp,LINUX,GPU;;;,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=,
test_ir_pass_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_mnist,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_se_resnext,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
......