Unverified commit e44ff495, authored by Yuang Liu, committed via GitHub

Fused attention pass mp support (#50320)

Parent a7539508
......@@ -24,7 +24,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
bool pre_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual) {
bool add_residual,
bool use_mp) {
// pre layer norm pattern
PDNode* pre_layer_norm_out_node{nullptr};
if (pre_layer_norm) {
......@@ -51,6 +52,28 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
pre_layer_norm_variance_node});
}
// c_identity for mp
PDNode* c_identity_input_node = pre_layer_norm ? pre_layer_norm_out_node : x;
PDNode* c_identity_out_node{nullptr};
if (use_mp) {
auto* c_identity_node =
pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity");
if (pre_layer_norm) {
c_identity_input_node->assert_is_op_input("c_identity", "X");
}
c_identity_out_node = pattern->NewNode(c_identity_out_repr())
->assert_is_op_output("c_identity");
c_identity_node->LinksFrom({c_identity_input_node})
.LinksTo({c_identity_out_node});
}
PDNode* fuse_qkv_input_node = x;
if (use_mp) {
fuse_qkv_input_node = c_identity_out_node;
} else if (pre_layer_norm) {
fuse_qkv_input_node = pre_layer_norm_out_node;
}
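  // With mp, the fused qkv matmul consumes the c_identity output; otherwise it
  // consumes the pre layer norm output (if any) or x directly.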
// fuse qkv pattern
auto* fuse_qkv_matmul_node =
pattern->NewNode(fuse_qkv_matmul_op_repr())->assert_is_op("matmul_v2");
......@@ -58,15 +81,11 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
->assert_is_op_input("matmul_v2", "Y");
auto* fuse_qkv_matmul_out_node = pattern->NewNode(fuse_qkv_matmul_out_repr())
->assert_is_op_output("matmul_v2");
if (pre_layer_norm) {
pre_layer_norm_out_node->assert_is_op_input("matmul_v2", "X");
fuse_qkv_matmul_node
->LinksFrom({pre_layer_norm_out_node, fuse_qkv_matmul_w_node})
.LinksTo({fuse_qkv_matmul_out_node});
} else {
fuse_qkv_matmul_node->LinksFrom({x, fuse_qkv_matmul_w_node})
.LinksTo({fuse_qkv_matmul_out_node});
if (pre_layer_norm || use_mp) {
fuse_qkv_input_node->assert_is_op_input("matmul_v2", "X");
}
fuse_qkv_matmul_node->LinksFrom({fuse_qkv_input_node, fuse_qkv_matmul_w_node})
.LinksTo({fuse_qkv_matmul_out_node});
auto* fuse_qkv_ele_add_node = pattern->NewNode(fuse_qkv_ele_add_op_repr())
->assert_is_op("elementwise_add");
......@@ -246,6 +265,20 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
->LinksFrom({out_linear_matmul_out_node, out_linear_ele_add_bias_node})
.LinksTo({out_linear_ele_add_out_node});
PDNode* mp_allreduce_out_node{nullptr};
if (use_mp) {
mp_allreduce_out_node = pattern->NewNode(mp_allreudce_sum_out_repr())
->assert_is_op_output("mp_allreduce_sum");
auto* mp_allreduce_node = pattern->NewNode(mp_allreudce_sum_op_repr())
->assert_is_op("mp_allreduce_sum");
out_linear_ele_add_out_node->assert_is_op_input("mp_allreduce_sum");
mp_allreduce_node->LinksFrom({out_linear_ele_add_out_node})
.LinksTo({mp_allreduce_out_node});
}
PDNode* out_linear_dropout_input_node =
use_mp ? mp_allreduce_out_node : out_linear_ele_add_out_node;
auto* out_linear_dropout_node =
pattern->NewNode(out_linear_dropout_op_repr())->assert_is_op("dropout");
auto* out_linear_dropout_mask_node =
......@@ -254,8 +287,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
auto* out_linear_dropout_out_node =
pattern->NewNode(out_linear_dropout_out_repr())
->assert_is_op_output("dropout");
out_linear_ele_add_out_node->assert_is_op_input("dropout", "X");
out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node})
out_linear_dropout_input_node->assert_is_op_input("dropout", "X");
out_linear_dropout_node->LinksFrom({out_linear_dropout_input_node})
.LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node});
if (!add_residual && pre_layer_norm) {
......@@ -324,7 +357,8 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
bool pre_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual) {
bool add_residual,
bool use_mp) {
// post layer norm
PDNode* post_layer_norm_grad_out_node{nullptr};
if (!pre_layer_norm) {
......@@ -424,6 +458,20 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
{out_linear_grad_input_node, out_linear_dropout_grad_mask_node})
.LinksTo({out_linear_dropout_grad_out_node});
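  // In the backward graph the grad of mp_allreduce_sum is a c_identity op, so
  // the mp_allreudce_sum_grad_* nodes below assert "c_identity".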
PDNode* mp_c_identity_out_node{nullptr};
if (use_mp) {
mp_c_identity_out_node = pattern->NewNode(mp_allreudce_sum_grad_out_repr())
->assert_is_op_output("c_identity", "Out");
auto* mp_c_identity_node = pattern->NewNode(mp_allreudce_sum_grad_op_repr())
->assert_is_op("c_identity");
out_linear_dropout_grad_out_node->assert_is_op_input("c_identity");
mp_c_identity_node->LinksFrom({out_linear_dropout_grad_out_node})
.LinksTo({mp_c_identity_out_node});
}
PDNode* out_linear_ele_add_grad_input_node =
use_mp ? mp_c_identity_out_node : out_linear_dropout_grad_out_node;
auto* out_linear_ele_add_grad_node =
pattern->NewNode(out_linear_ele_add_grad_op_repr())
->assert_is_op("elementwise_add_grad");
......@@ -439,10 +487,10 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
auto* out_linear_ele_add_grad_bias_grad_node =
pattern->NewNode(out_linear_ele_add_grad_bias_grad_repr())
->assert_is_op_output("elementwise_add_grad", "Y@GRAD");
out_linear_dropout_grad_out_node->assert_is_op_input("elementwise_add_grad",
"Out@GRAD");
out_linear_ele_add_grad_input_node->assert_is_op_input("elementwise_add_grad",
"Out@GRAD");
out_linear_ele_add_grad_node
->LinksFrom({out_linear_dropout_grad_out_node,
->LinksFrom({out_linear_ele_add_grad_input_node,
out_linear_ele_add_grad_x_node,
out_linear_ele_add_grad_bias_node})
.LinksTo({out_linear_ele_add_grad_x_grad_node,
......@@ -699,54 +747,78 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
.LinksTo(
{fuse_qkv_matmul_grad_x_grad_node, fuse_qkv_matmul_grad_w_grad_node});
if (!pre_layer_norm) {
return fuse_qkv_matmul_grad_x_grad_node;
PDNode* mp_allreduce_out_node{nullptr};
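  // Conversely, the grad of c_identity is a c_allreduce_sum op, so the
  // c_identity_grad_* nodes below assert "c_allreduce_sum".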
if (use_mp) {
mp_allreduce_out_node = pattern->NewNode(c_identity_grad_out_repr())
->assert_is_op_output("c_allreduce_sum", "Out");
auto* mp_allreduce_node = pattern->NewNode(c_identity_grad_op_repr())
->assert_is_op("c_allreduce_sum");
fuse_qkv_matmul_grad_x_grad_node->assert_is_op_input("c_allreduce_sum",
"X");
mp_allreduce_node->LinksFrom({fuse_qkv_matmul_grad_x_grad_node})
.LinksTo({mp_allreduce_out_node});
}
PDNode* pre_layer_norm_input_node =
use_mp ? mp_allreduce_out_node : fuse_qkv_matmul_grad_x_grad_node;
if (!pre_layer_norm && !add_residual) {
return pre_layer_norm_input_node;
}
PDNode* pre_layer_norm_grad_x_grad_node{nullptr};
if (pre_layer_norm) {
// pre layer norm
auto* pre_layer_norm_grad_node =
pattern->NewNode(pre_layer_norm_grad_op_repr())
->assert_is_op("layer_norm_grad");
auto* pre_layer_norm_grad_scale_node =
pattern->NewNode(pre_layer_norm_grad_scale_repr())
->assert_is_op_input("layer_norm_grad", "Scale");
auto* pre_layer_norm_grad_bias_node =
pattern->NewNode(pre_layer_norm_grad_bias_repr())
->assert_is_op_input("layer_norm_grad", "Bias");
auto* pre_layer_norm_grad_mean_node =
pattern->NewNode(pre_layer_norm_grad_mean_repr())
->assert_is_op_input("layer_norm_grad", "Mean");
auto* pre_layer_norm_grad_variance_node =
pattern->NewNode(pre_layer_norm_grad_variance_repr())
->assert_is_op_input("layer_norm_grad", "Variance");
auto* pre_layer_norm_grad_x_node =
add_residual ? residual_ele_add_grad_x_node
: pattern->NewNode(pre_layer_norm_grad_x_repr())
->assert_is_op_input("layer_norm_grad", "X");
auto* pre_layer_norm_grad_scale_grad_node =
pattern->NewNode(pre_layer_norm_grad_scale_grad_repr())
->assert_is_op_output("layer_norm_grad", "Scale@GRAD");
auto* pre_layer_norm_grad_bias_grad_node =
pattern->NewNode(pre_layer_norm_grad_bias_grad_repr())
->assert_is_op_output("layer_norm_grad", "Bias@GRAD");
pre_layer_norm_grad_x_grad_node =
pattern->NewNode(pre_layer_norm_grad_x_grad_repr())
->assert_is_op_output("layer_norm_grad", "X@GRAD");
pre_layer_norm_input_node->assert_is_op_input("layer_norm_grad", "Y@GRAD");
pre_layer_norm_grad_node
->LinksFrom({pre_layer_norm_input_node,
pre_layer_norm_grad_scale_node,
pre_layer_norm_grad_bias_node,
pre_layer_norm_grad_mean_node,
pre_layer_norm_grad_variance_node,
pre_layer_norm_grad_x_node})
.LinksTo({pre_layer_norm_grad_scale_grad_node,
pre_layer_norm_grad_bias_grad_node,
pre_layer_norm_grad_x_grad_node});
}
// pre layer norm
auto* pre_layer_norm_grad_node =
pattern->NewNode(pre_layer_norm_grad_op_repr())
->assert_is_op("layer_norm_grad");
auto* pre_layer_norm_grad_scale_node =
pattern->NewNode(pre_layer_norm_grad_scale_repr())
->assert_is_op_input("layer_norm_grad", "Scale");
auto* pre_layer_norm_grad_bias_node =
pattern->NewNode(pre_layer_norm_grad_bias_repr())
->assert_is_op_input("layer_norm_grad", "Bias");
auto* pre_layer_norm_grad_mean_node =
pattern->NewNode(pre_layer_norm_grad_mean_repr())
->assert_is_op_input("layer_norm_grad", "Mean");
auto* pre_layer_norm_grad_variance_node =
pattern->NewNode(pre_layer_norm_grad_variance_repr())
->assert_is_op_input("layer_norm_grad", "Variance");
auto* pre_layer_norm_grad_x_node =
add_residual ? residual_ele_add_grad_x_node
: pattern->NewNode(pre_layer_norm_grad_x_repr())
->assert_is_op_input("layer_norm_grad", "X");
auto* pre_layer_norm_grad_scale_grad_node =
pattern->NewNode(pre_layer_norm_grad_scale_grad_repr())
->assert_is_op_output("layer_norm_grad", "Scale@GRAD");
auto* pre_layer_norm_grad_bias_grad_node =
pattern->NewNode(pre_layer_norm_grad_bias_grad_repr())
->assert_is_op_output("layer_norm_grad", "Bias@GRAD");
auto* pre_layer_norm_grad_x_grad_node =
pattern->NewNode(pre_layer_norm_grad_x_grad_repr())
->assert_is_op_output("layer_norm_grad", "X@GRAD");
fuse_qkv_matmul_grad_x_grad_node->assert_is_op_input("layer_norm_grad",
"Y@GRAD");
pre_layer_norm_grad_node
->LinksFrom({fuse_qkv_matmul_grad_x_grad_node,
pre_layer_norm_grad_scale_node,
pre_layer_norm_grad_bias_node,
pre_layer_norm_grad_mean_node,
pre_layer_norm_grad_variance_node,
pre_layer_norm_grad_x_node})
.LinksTo({pre_layer_norm_grad_scale_grad_node,
pre_layer_norm_grad_bias_grad_node,
pre_layer_norm_grad_x_grad_node});
PDNode* grad_accumulation_x_input_node = fuse_qkv_matmul_grad_x_grad_node;
if (pre_layer_norm) {
grad_accumulation_x_input_node = pre_layer_norm_grad_x_grad_node;
} else if (use_mp) {
grad_accumulation_x_input_node = mp_allreduce_out_node;
}
if (!add_residual) {
return pre_layer_norm_grad_x_grad_node;
return grad_accumulation_x_input_node;
}
auto* grad_accumulation_sum_node =
......@@ -754,9 +826,11 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
auto* grad_accumulation_sum_out_node =
pattern->NewNode(grad_accumulation_out_repr())
->assert_is_op_output("sum");
residual_ele_add_grad_x_grad_node->assert_is_op_input("sum");
grad_accumulation_x_input_node->assert_is_op_input("sum");
grad_accumulation_sum_node
->LinksFrom(
{pre_layer_norm_grad_x_grad_node, residual_ele_add_grad_x_grad_node})
{grad_accumulation_x_input_node, residual_ele_add_grad_x_grad_node})
.LinksTo({grad_accumulation_sum_out_node});
return grad_accumulation_sum_out_node;
......@@ -771,10 +845,64 @@ void FusedAttentionsPass::ApplyImpl(Graph* graph) const {
graph = PreMaskDropResFwd(graph, &cache);
graph = PreMaskDropResBwd(graph, &cache);
cache.ResetCache();
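  // Also match the tensor-parallel (mp) variant of the same fused attention
  // subgraph.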
graph = PreMaskDropResMPFwd(graph, &cache);
graph = PreMaskDropResMPBwd(graph, &cache);
cache.ResetCache();
}
ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
Graph* graph, FusedAttentionPassCache* cache) const {
return ForwardHandlerHelper(graph,
cache,
/* pre_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true,
/* use_mp */ false);
}
ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
Graph* graph, FusedAttentionPassCache* cache) const {
return BackwardHandlerHelper(graph,
cache,
/* pre_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true,
/* use_mp */ false);
}
ir::Graph* FusedAttentionsPass::PreMaskDropResMPFwd(
Graph* graph, FusedAttentionPassCache* cache) const {
return ForwardHandlerHelper(graph,
cache,
/* pre_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true,
/* use_mp */ true);
}
ir::Graph* FusedAttentionsPass::PreMaskDropResMPBwd(
Graph* graph, FusedAttentionPassCache* cache) const {
return BackwardHandlerHelper(graph,
cache,
/* pre_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true,
/* use_mp */ true);
}
ir::Graph* FusedAttentionsPass::ForwardHandlerHelper(
Graph* graph,
FusedAttentionPassCache* cache,
bool pre_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual,
bool use_mp) const {
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x"))
......@@ -783,11 +911,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
patterns::FusedAttentionPattern fused_attention_pattern(
gpd.mutable_pattern(), "fused_attention_pattern");
fused_attention_pattern(x,
/* pre_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true);
fused_attention_pattern(
x, pre_layer_norm, has_attn_mask, do_dropout, add_residual, use_mp);
int found_fused_attention = 0;
......@@ -840,10 +965,44 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern);
std::unordered_set<const Node*> remove_nodes = {pre_layer_norm_op_node,
fuse_qkv_matmul_op_node,
fuse_qkv_ele_add_op_node,
fuse_qkv_reshape_op_node,
fuse_qkv_transpose_op_node,
fuse_qkv_split_op_node,
qk_matmul_op_node,
qk_scale_op_node,
add_mask_ele_add_op_node,
qk_softmax_op_node,
attn_dropout_op_node,
qkv_matmul_op_node,
qkv_transpose_op_node,
qkv_reshape_op_node,
out_linear_matmul_op_node,
out_linear_ele_add_op_node,
out_linear_dropout_op_node,
residual_ele_add_op_node};
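  // With mp, also remove the c_identity and mp_allreduce_sum ops and record
  // the communication ring id so it can be forwarded to the fused op.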
int ring_id = -1;
if (use_mp) {
GET_IR_NODE_FROM_SUBGRAPH(
c_identity_op_node, c_identity_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mp_allreudce_sum_op_node,
mp_allreudce_sum_op,
fused_attention_pattern);
remove_nodes.insert(c_identity_op_node);
remove_nodes.insert(mp_allreudce_sum_op_node);
ring_id = PADDLE_GET_CONST(
int, mp_allreudce_sum_op_node->Op()->GetAttr("ring_id"));
}
std::string cache_anchor_name = fuse_qkv_matmul_w_node->Var()->Name();
OpDesc fused_attention_op_desc(pre_layer_norm_op_node->Op()->Block());
fused_attention_op_desc.SetType("fused_attention");
fused_attention_op_desc.SetAttr("ring_id", ring_id);
fused_attention_op_desc.SetInput("X", {subgraph.at(x)->Name()});
cache->InsertIntoCache(GenerateCacheKey(cache_anchor_name, "X", block_id),
subgraph.at(x));
......@@ -1090,25 +1249,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
IR_NODE_LINK_TO(fused_attention_node, out_linear_dropout_mask_node);
IR_NODE_LINK_TO(fused_attention_node, residual_ele_add_out_node);
GraphSafeRemoveNodes(g,
{pre_layer_norm_op_node,
fuse_qkv_matmul_op_node,
fuse_qkv_ele_add_op_node,
fuse_qkv_reshape_op_node,
fuse_qkv_transpose_op_node,
fuse_qkv_split_op_node,
qk_matmul_op_node,
qk_scale_op_node,
add_mask_ele_add_op_node,
qk_softmax_op_node,
attn_dropout_op_node,
qkv_matmul_op_node,
qkv_transpose_op_node,
qkv_reshape_op_node,
out_linear_matmul_op_node,
out_linear_ele_add_op_node,
out_linear_dropout_op_node,
residual_ele_add_op_node});
GraphSafeRemoveNodes(g, remove_nodes);
found_fused_attention++;
};
......@@ -1119,8 +1260,14 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
return graph;
}
ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
Graph* graph, FusedAttentionPassCache* cache) const {
ir::Graph* FusedAttentionsPass::BackwardHandlerHelper(
Graph* graph,
FusedAttentionPassCache* cache,
bool pre_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual,
bool use_mp) const {
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x"))
......@@ -1129,11 +1276,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
patterns::FusedAttentionGradPattern fused_attention_grad_pattern(
gpd.mutable_pattern(), "fused_attention_grad_pattern");
fused_attention_grad_pattern(x,
/* pre_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true);
fused_attention_grad_pattern(
x, pre_layer_norm, has_attn_mask, do_dropout, add_residual, use_mp);
int found_fused_attention = 0;
......@@ -1200,6 +1344,41 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
grad_accumulation_sum_op,
fused_attention_grad_pattern);
std::unordered_set<const Node*> remove_nodes = {
residual_ele_add_grad_op_node,
out_linear_dropout_grad_op_node,
out_linear_ele_add_grad_op_node,
out_linear_matmul_grad_op_node,
qkv_reshape_grad_op_node,
qkv_transpose_grad_op_node,
qkv_matmul_grad_op_node,
attn_dropout_grad_op_node,
qk_softmax_grad_op_node,
add_mask_ele_add_grad_op_node,
qk_scale_grad_op_node,
qk_matmul_grad_op_node,
fuse_qkv_split_grad_op_node,
fuse_qkv_transpose_grad_op_node,
fuse_qkv_reshape_grad_op_node,
fuse_qkv_ele_add_grad_op_node,
fuse_qkv_matmul_grad_op_node,
pre_layer_norm_grad_op_node,
grad_accumulation_sum_op_node};
int ring_id = -1;
if (use_mp) {
GET_IR_NODE_FROM_SUBGRAPH(mp_allreudce_sum_grad_op_node,
mp_allreudce_sum_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity_grad_op_node,
c_identity_grad_op,
fused_attention_grad_pattern);
remove_nodes.insert(mp_allreudce_sum_grad_op_node);
remove_nodes.insert(c_identity_grad_op_node);
ring_id = PADDLE_GET_CONST(
int, mp_allreudce_sum_grad_op_node->Op()->GetAttr("ring_id"));
}
OpDesc fused_attention_grad_op_desc(
residual_ele_add_grad_op_node->Op()->Block());
fused_attention_grad_op_desc.SetType("fused_attention_grad");
......@@ -1347,7 +1526,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
"ln_epsilon",
PADDLE_GET_CONST(
float, pre_layer_norm_grad_op_node->Op()->GetAttr("epsilon")));
fused_attention_grad_op_desc.SetAttr("ring_id", -1);
fused_attention_grad_op_desc.SetAttr("ring_id", ring_id);
GET_IR_NODE_FROM_SUBGRAPH(qkv_matmul_grad_x_grad_node,
qkv_matmul_grad_x_grad,
......@@ -1497,26 +1676,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
IR_NODE_LINK_TO(src_mask_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(transpose_out_2_node, fused_attention_grad_node);
GraphSafeRemoveNodes(g,
{residual_ele_add_grad_op_node,
out_linear_dropout_grad_op_node,
out_linear_ele_add_grad_op_node,
out_linear_matmul_grad_op_node,
qkv_reshape_grad_op_node,
qkv_transpose_grad_op_node,
qkv_matmul_grad_op_node,
attn_dropout_grad_op_node,
qk_softmax_grad_op_node,
add_mask_ele_add_grad_op_node,
qk_scale_grad_op_node,
qk_matmul_grad_op_node,
fuse_qkv_split_grad_op_node,
fuse_qkv_transpose_grad_op_node,
fuse_qkv_reshape_grad_op_node,
fuse_qkv_ele_add_grad_op_node,
fuse_qkv_matmul_grad_op_node,
pre_layer_norm_grad_op_node,
grad_accumulation_sum_op_node});
GraphSafeRemoveNodes(g, remove_nodes);
found_fused_attention++;
};
......
......@@ -33,6 +33,7 @@ namespace patterns {
// 2. Add attn mask for qk product before the softmax or not.
// 3. Do attn dropout or not.
// 4. Add residual to the out linear result or not.
// 5. Use tensor model parallelism or not.
struct FusedAttentionPattern : public PatternBase {
FusedAttentionPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
......@@ -41,7 +42,8 @@ struct FusedAttentionPattern : public PatternBase {
bool pre_layer_norm, // do pre ln or not
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
bool add_residual, // add residual to out linear or not
bool use_mp); // use tensor parallel or not
// pre layer norm
PATTERN_DECL_NODE(pre_layer_norm_op);
......@@ -51,6 +53,10 @@ struct FusedAttentionPattern : public PatternBase {
PATTERN_DECL_NODE(pre_layer_norm_mean);
PATTERN_DECL_NODE(pre_layer_norm_variance);
// c_identity for mp
PATTERN_DECL_NODE(c_identity_op);
PATTERN_DECL_NODE(c_identity_out);
// fuse qkv projection
PATTERN_DECL_NODE(fuse_qkv_matmul_op);
PATTERN_DECL_NODE(fuse_qkv_matmul_w);
......@@ -111,6 +117,10 @@ struct FusedAttentionPattern : public PatternBase {
PATTERN_DECL_NODE(out_linear_ele_add_bias);
PATTERN_DECL_NODE(out_linear_ele_add_out);
// allreduce for mp
PATTERN_DECL_NODE(mp_allreudce_sum_op);
PATTERN_DECL_NODE(mp_allreudce_sum_out);
PATTERN_DECL_NODE(out_linear_dropout_op);
PATTERN_DECL_NODE(out_linear_dropout_out);
PATTERN_DECL_NODE(out_linear_dropout_mask);
......@@ -131,13 +141,14 @@ struct FusedAttentionPattern : public PatternBase {
// Declare the grad pattern for multi head attention
struct FusedAttentionGradPattern : public PatternBase {
FusedAttentionGradPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
: PatternBase(pattern, name_scope, "fused_attention_grad_pattern") {}
PDNode* operator()(PDNode* x,
bool pre_layer_norm, // pre ln
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
bool add_residual, // add residual to out linear or not
bool use_mp); // use tensor parallel or not
// post layer norm grad
PATTERN_DECL_NODE(post_layer_norm_grad_op);
......@@ -162,6 +173,10 @@ struct FusedAttentionGradPattern : public PatternBase {
PATTERN_DECL_NODE(out_linear_dropout_grad_mask);
PATTERN_DECL_NODE(out_linear_dropout_grad_out);
// grad of the mp allreduce (a c_identity op)
PATTERN_DECL_NODE(mp_allreudce_sum_grad_op); // c_identity
PATTERN_DECL_NODE(mp_allreudce_sum_grad_out);
PATTERN_DECL_NODE(out_linear_ele_add_grad_op);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias);
......@@ -235,6 +250,10 @@ struct FusedAttentionGradPattern : public PatternBase {
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad);
// grad of c_identity (an allreduce op) for mp
PATTERN_DECL_NODE(c_identity_grad_op); // mp_allreduce_sum
PATTERN_DECL_NODE(c_identity_grad_out);
// pre layer norm grad
PATTERN_DECL_NODE(pre_layer_norm_grad_op);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale);
......@@ -296,6 +315,7 @@ class FusedAttentionsPass : public FusePassBase {
// 4. Add residual? [Res]
// 5. Do post layer norm? [Post]
// 6. Forward or Backward? [Fwd/Bwd]
// 7. Use tensor model parallel? [MP]
// If true, the function name will have an abbreviation part.
// If false, the function name won't contain an abbreviation for it.
......@@ -305,6 +325,28 @@ class FusedAttentionsPass : public FusePassBase {
ir::Graph* PreMaskDropResBwd(Graph* graph,
FusedAttentionPassCache* cache) const;
ir::Graph* PreMaskDropResMPFwd(Graph* graph,
FusedAttentionPassCache* cache) const;
ir::Graph* PreMaskDropResMPBwd(Graph* graph,
FusedAttentionPassCache* cache) const;
ir::Graph* ForwardHandlerHelper(Graph* graph,
FusedAttentionPassCache* cache,
bool pre_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual,
bool use_mp) const;
ir::Graph* BackwardHandlerHelper(Graph* graph,
FusedAttentionPassCache* cache,
bool pre_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual,
bool use_mp) const;
const std::string GenerateCacheKey(const std::string anchor,
const std::string var_name,
int block_id) const {
......
......@@ -120,6 +120,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
auto y_dim = ctx->GetInputDim("QKVW");
int dim_head;
int hidden_size;
int nranks = 1;
if (transpose_qkv_wb) {
PADDLE_ENFORCE_EQ(y_dim.size(),
2,
......@@ -149,8 +150,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 2"
"(dim_embed, 3 * dim_embed)."));
} else {
// compute the mp nranks
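      // (with tensor parallelism the column-split qkv weight has shape
      //  [dim_embed, 3 * dim_embed / nranks])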
nranks = (y_dim[0] * 3) / y_dim[1];
}
dim_head = y_dim[0] / num_heads;
dim_head = y_dim[0] / (num_heads * nranks);
hidden_size = y_dim[0];
} else {
PADDLE_ENFORCE_EQ(y_dim.size(),
......@@ -210,11 +214,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
}
if (transpose_qkv_wb) {
// [batch_size, seq_len, 3 * hidden_size]
ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], 3 * hidden_size});
// [batch_size, seq_len, 3 * num_heads * dim_head]
ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], 3 * num_heads * dim_head});
if (ctx->HasInput("QKVBias")) {
ctx->SetOutputDim("QKVBiasOut", {x_dim[0], x_dim[1], 3 * hidden_size});
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], 3 * num_heads * dim_head});
}
} else {
// [batch_size, seq_len, 3, num_head, head_size]
......
......@@ -217,13 +217,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
int num_head;
int dim_head;
int nranks = 1;
// get num_head and dim_head in two different ways
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
num_head = num_heads;
dim_head = dim_embed / num_head;
dim_head = dim_embed / (num_head * nranks);
}
int bsz_seq = batch_size * max_seq_len;
......@@ -579,12 +581,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
int dim_embed = input_x_dims[2];
int num_head;
int dim_head;
int nranks = 1;
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
num_head = num_heads;
dim_head = dim_embed / num_head;
dim_head = dim_embed / (num_head * nranks);
}
int bsz_seq = batch_size * max_seq_len;
......
......@@ -908,3 +908,15 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_dygraph_save_for_auto_infer
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if(WITH_GPU)
bash_test_modules(
test_fused_attention_pass_with_mp
START_BASH
test_fused_attention_pass_with_mp.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21400;http_proxy=;https_proxy=")
set_tests_properties(test_fused_attention_pass_with_mp PROPERTIES TIMEOUT
"120")
endif()
# Copyright (c) 2013 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.fluid as fluid
import paddle.nn.functional as F
from paddle.distributed.passes import PassManager, new_pass
paddle.enable_static()
class MultiHeadAttentionWithMP(paddle.nn.Layer):
def __init__(
self,
embed_dim,
num_heads,
add_residual=True,
pre_ln=True,
attn_dropout=True,
):
super(MultiHeadAttentionWithMP, self).__init__()
self.embed_dim = embed_dim
self.kdim = embed_dim
self.vdim = embed_dim
self.num_heads = num_heads
self.add_residual = add_residual
self.pre_ln = pre_ln
self.attn_dropout = attn_dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert num_heads % 2 == 0
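        # two-way tensor parallelism: each rank keeps half of the heads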
self.num_heads = num_heads // 2
self.norm1 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)
self.norm2 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)
self.qkv_proj = paddle.nn.Linear(
embed_dim, 3 * self.num_heads * self.head_dim
)
self.out_proj = paddle.nn.Linear(
self.num_heads * self.head_dim, embed_dim
)
self.dropout = paddle.nn.Dropout(1e-10, mode="upscale_in_train")
def forward(self, x, attn_mask=None):
residual = x
if self.pre_ln:
# pre layer norm
x = self.norm1(x)
x = paddle.distributed.collective._c_identity(x)
# compute qkv
qkv = self.qkv_proj(x)
qkv = paddle.reshape(qkv, [0, 0, 3 * self.num_heads, self.head_dim])
qkv = paddle.transpose(qkv, [0, 2, 1, 3])
q, k, v = paddle.split(qkv, num_or_sections=3, axis=1)
# compute core attention
q = paddle.scale(q, scale=self.head_dim**-0.5)
product = paddle.matmul(x=q, y=k, transpose_y=True)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.attn_dropout:
weights = F.dropout(
weights, 0.1, training=self.training, mode="upscale_in_train"
)
out = paddle.matmul(weights, v)
out = paddle.transpose(out, perm=[0, 2, 1, 3])
out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
out = paddle.distributed.collective._mp_allreduce(
out, use_calc_stream=True, use_model_parallel=True
)
out = self.dropout(out)
if self.add_residual:
out = residual + out
if not self.pre_ln:
# post layer norm
out = self.norm2(out)
return out
class TestFusedAttentionPassWithMP(unittest.TestCase):
def setUp(self):
fleet.init()
self.endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')
self.current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
self.nranks = len(self.endpoints)
self.rank = self.endpoints.index(self.current_endpoint)
self.gpu_id = int(os.getenv("FLAGS_selected_gpus"))
self.place = fluid.CUDAPlace(self.gpu_id)
self.exe = fluid.Executor(self.place)
self.endpoints.remove(self.current_endpoint)
self.other_endpoints = self.endpoints
self.add_residual = True
self.pre_ln = True
self.attn_dropout = True
self.add_mask = True
self.x_data = None
self.mask_data = None
def get_rst(self, use_pass=False):
batch_size = 2
seq_len = 1024
hidden_size = 768
num_heads = 12
np.random.seed(1234)
if self.x_data is None:
self.x_data = np.random.rand(batch_size, seq_len, seq_len).astype(
'float32'
)
self.mask_data = np.random.rand(
batch_size, num_heads // 2, seq_len, seq_len
).astype('float32')
main_prog = paddle.static.Program()
main_prog.random_seed = 1234
startup_prog = paddle.static.Program()
startup_prog.random_seed = 1234
with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
name="x",
shape=[-1, seq_len, seq_len],
dtype='float32',
)
if self.add_mask:
attn_mask = paddle.static.data(
name="attn_mask",
shape=[-1, num_heads // 2, seq_len, seq_len],
dtype='float32',
)
else:
attn_mask = None
data_linear = paddle.nn.Linear(seq_len, hidden_size)
multi_head_attn = MultiHeadAttentionWithMP(
hidden_size,
num_heads,
add_residual=self.add_residual,
pre_ln=self.pre_ln,
attn_dropout=self.attn_dropout,
)
attn_input = data_linear(data)
out = multi_head_attn(attn_input, attn_mask)
loss = paddle.mean(out)
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(loss)
startup_block = startup_prog.global_block()
nccl_id_var = startup_block.create_var(
name=fluid.unique_name.generate('nccl_id'),
persistable=True,
type=fluid.core.VarDesc.VarType.RAW,
)
startup_block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': self.rank,
'endpoint': self.current_endpoint,
'other_endpoints': self.other_endpoints,
},
)
startup_block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': self.nranks,
'rank': self.rank,
'ring_id': 0,
'device_id': self.gpu_id,
},
)
if use_pass:
pass_manager = PassManager([new_pass("fused_attention")])
pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops
assert ops[2].type == 'fused_attention'
assert ops[3].type == 'reduce_mean'
assert ops[5].type == 'reduce_mean_grad'
assert ops[6].type == 'fused_attention_grad'
            # two ops for linear, one op for reduce mean,
            # one fill constant,
            # one op for reduce mean grad, two ops for linear bwd;
            # together with the two fused ops that makes nine ops,
            # so the tenth op (index 9) should be the first optimizer op
assert ops[9].type == 'sgd'
self.exe.run(startup_prog)
for i in range(2):
rst = self.exe.run(
main_prog,
feed={'x': self.x_data, 'attn_mask': self.mask_data},
fetch_list=[loss],
)
return rst
def test_pass(self):
fused_rst = self.get_rst(use_pass=True)
non_fused_rst = self.get_rst()
assert np.allclose(fused_rst, non_fused_rst, atol=1e-5)
if __name__ == "__main__":
unittest.main()
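For reference, a rough dimension check for the tensor-parallel layer in this test (a sketch only, using the shapes from get_rst: hidden_size 768, 12 heads, 2 mp ranks; treating the per-rank head count as what the fused op sees as num_heads, an assumption based on the dim_head formula in the op above):

# Illustrative only: hypothetical shapes matching the test above.
embed_dim, num_heads, nranks = 768, 12, 2
head_dim = embed_dim // num_heads              # 64
local_heads = num_heads // nranks              # 6 heads kept by each rank
qkv_w_shape = (embed_dim, 3 * local_heads * head_dim)  # (768, 1152), column-split
# The fused op recovers nranks from the weight shape alone ...
assert (qkv_w_shape[0] * 3) // qkv_w_shape[1] == nranks
# ... and the per-rank head dim from the local head count and nranks.
assert qkv_w_shape[0] // (local_heads * nranks) == head_dim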
#!/bin/bash
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# use default values
# FIXME: randomly fails with "Unknown command line" errors for -c (or -m).
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch fused_attention_pass_with_mp.py
......@@ -58,6 +58,7 @@ test_fleet_recompute_meta_optimizer,LINUX;WIN32,GPU;XPU;ASCEND;ASCEND_CL,,,test_
test_fleet_private_function,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_new_group,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_new_group.sh,2,,http_proxy=;https_proxy=,
test_c_comm_init_op,LINUX,GPU;XPU;ASCEND;ASCEND_CL,120,DIST,test_c_comm_init_op.sh,2,,http_proxy=;https_proxy=,
test_fused_attention_pass_with_mp,LINUX,GPU;;;,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=,
test_ir_pass_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_mnist,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_se_resnext,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
......