Unverified commit 7e8ef328, authored by Yuang Liu, committed by GitHub

Fused attention pass backward op replace. (#50186)

Parent f2ec69b4
@@ -766,12 +766,15 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
void FusedAttentionsPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
FusedAttentionPassCache cache;
- graph = PreMaskDropResFwd(graph);
graph = PreMaskDropResFwd(graph, &cache);
- graph = PreMaskDropResBwd(graph);
graph = PreMaskDropResBwd(graph, &cache);
cache.ResetCache();
}
- ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
Graph* graph, FusedAttentionPassCache* cache) const {
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x"))
@@ -792,6 +795,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
Graph* g) {
VLOG(3) << "handle FusedMultiHeadAttention pass's fusion";
int block_id = g->GetBlockId();
GET_IR_NODE_FROM_SUBGRAPH(
pre_layer_norm_op_node, pre_layer_norm_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
@@ -833,9 +838,15 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(
residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern);
std::string cache_anchor_name = fuse_qkv_matmul_w_node->Var()->Name();
OpDesc fused_attention_op_desc(pre_layer_norm_op_node->Op()->Block());
fused_attention_op_desc.SetType("fused_attention");
fused_attention_op_desc.SetInput("X", {subgraph.at(x)->Name()});
cache->InsertIntoCache(GenerateCacheKey(cache_anchor_name, "X", block_id),
subgraph.at(x));
fused_attention_op_desc.SetAttr("pre_layer_norm", true); fused_attention_op_desc.SetAttr("pre_layer_norm", true);
GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_scale_node, GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_scale_node,
...@@ -860,6 +871,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const { ...@@ -860,6 +871,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
{pre_layer_norm_mean_node->Name()}); {pre_layer_norm_mean_node->Name()});
fused_attention_op_desc.SetOutput("LnVariance", fused_attention_op_desc.SetOutput("LnVariance",
{pre_layer_norm_variance_node->Name()}); {pre_layer_norm_variance_node->Name()});
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "LnScale", block_id),
pre_layer_norm_scale_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "LnBias", block_id),
pre_layer_norm_bias_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "LnOut", block_id),
pre_layer_norm_out_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "LnMean", block_id),
pre_layer_norm_mean_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "LnVariance", block_id),
pre_layer_norm_variance_node);
fused_attention_op_desc.SetAttr(
"epsilon",
PADDLE_GET_CONST(float,
@@ -869,8 +895,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
std::vector<int> shape = PADDLE_GET_CONST(
std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
fused_attention_op_desc.SetAttr("num_heads", shape[2]);
- GET_IR_NODE_FROM_SUBGRAPH(
- fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
@@ -891,6 +915,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
{fuse_qkv_ele_add_out_node->Name()});
fused_attention_op_desc.SetOutput("TransposeOut2",
{fuse_qkv_transpose_out_node->Name()});
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "QKVW", block_id),
fuse_qkv_matmul_w_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "QKVBias", block_id),
fuse_qkv_ele_add_bias_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "QKVOut", block_id),
fuse_qkv_matmul_out_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "QKVBiasOut", block_id),
fuse_qkv_ele_add_out_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "TransposeOut2", block_id),
fuse_qkv_transpose_out_node);
GET_IR_NODE_FROM_SUBGRAPH(
qk_matmul_out_node, qk_matmul_out, fused_attention_pattern);
@@ -911,12 +950,24 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(
qkv_reshape_out_node, qkv_reshape_out, fused_attention_pattern);
fused_attention_op_desc.SetOutput("QKOut", {qk_matmul_out_node->Name()});
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "QKOut", block_id),
qk_matmul_out_node);
fused_attention_op_desc.SetInput("SrcMask", fused_attention_op_desc.SetInput("SrcMask",
{add_mask_ele_add_mask_node->Name()}); {add_mask_ele_add_mask_node->Name()});
fused_attention_op_desc.SetOutput("SrcMaskOut", fused_attention_op_desc.SetOutput("SrcMaskOut",
{add_mask_ele_add_out_node->Name()}); {add_mask_ele_add_out_node->Name()});
fused_attention_op_desc.SetOutput("SoftmaxOut", fused_attention_op_desc.SetOutput("SoftmaxOut",
{qk_softmax_out_node->Name()}); {qk_softmax_out_node->Name()});
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "SrcMask", block_id),
add_mask_ele_add_mask_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "SrcMaskOut", block_id),
add_mask_ele_add_out_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "SoftmaxOut", block_id),
qk_softmax_out_node);
fused_attention_op_desc.SetAttr(
"attn_dropout_rate",
PADDLE_GET_CONST(float,
@@ -943,6 +994,18 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
fused_attention_op_desc.SetOutput("QKTVOut", {qkv_matmul_out_node->Name()});
fused_attention_op_desc.SetOutput("FMHAOut",
{qkv_reshape_out_node->Name()});
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "AttnDropoutMaskOut", block_id),
attn_dropout_mask_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "AttnDropoutOut", block_id),
attn_dropout_out_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "QKTVOut", block_id),
qkv_matmul_out_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "FMHAOut", block_id),
qkv_reshape_out_node);
GET_IR_NODE_FROM_SUBGRAPH(
out_linear_matmul_w_node, out_linear_matmul_w, fused_attention_pattern);
@@ -952,15 +1015,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_bias_node,
out_linear_ele_add_bias,
fused_attention_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_out_node,
- out_linear_ele_add_out,
- fused_attention_pattern);
fused_attention_op_desc.SetInput("OutLinearW",
{out_linear_matmul_w_node->Name()});
fused_attention_op_desc.SetInput("OutLinearBias",
{out_linear_ele_add_bias_node->Name()});
fused_attention_op_desc.SetOutput("OutLinearOut",
{out_linear_matmul_out_node->Name()});
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "OutLinearW", block_id),
out_linear_matmul_w_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "OutLinearBias", block_id),
out_linear_ele_add_bias_node);
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "OutLinearOut", block_id),
out_linear_matmul_out_node);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_mask_node,
out_linear_dropout_mask,
fused_attention_pattern);
@@ -983,6 +1052,9 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
"dropout_implementation")));
fused_attention_op_desc.SetOutput("DropoutMaskOut",
{out_linear_dropout_mask_node->Name()});
cache->InsertIntoCache(
GenerateCacheKey(cache_anchor_name, "DropoutMaskOut", block_id),
out_linear_dropout_mask_node);
GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_out_node,
residual_ele_add_out,
@@ -1037,6 +1109,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
out_linear_ele_add_op_node,
out_linear_dropout_op_node,
residual_ele_add_op_node});
found_fused_attention++;
};
@@ -1046,7 +1119,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
return graph;
}
- ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
Graph* graph, FusedAttentionPassCache* cache) const {
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x"))
@@ -1067,6 +1141,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
Graph* g) {
VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion";
int block_id = g->GetBlockId();
GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node,
residual_ele_add_grad_op,
fused_attention_grad_pattern);
@@ -1124,7 +1200,302 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
grad_accumulation_sum_op,
fused_attention_grad_pattern);
- // TODO(Yuang Liu): finish the handler
OpDesc fused_attention_grad_op_desc(
residual_ele_add_grad_op_node->Op()->Block());
fused_attention_grad_op_desc.SetType("fused_attention_grad");
fused_attention_grad_op_desc.SetInput("Y@GRAD", {subgraph.at(x)->Name()});
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_matmul_grad_w_node,
fuse_qkv_matmul_grad_w,
fused_attention_grad_pattern);
std::string cache_anchor_name = fuse_qkv_matmul_grad_w_node->Var()->Name();
auto* x_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "X", block_id));
auto* attn_dropout_mask_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "AttnDropoutMaskOut", block_id));
auto* attn_dropout_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "AttnDropoutOut", block_id));
auto* dropout_mask_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "DropoutMaskOut", block_id));
auto* fmha_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "FMHAOut", block_id));
auto* ln_bias_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "LnBias", block_id));
auto* ln_mean_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "LnMean", block_id));
auto* ln_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "LnOut", block_id));
auto* ln_scale_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "LnScale", block_id));
auto* ln_variance_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "LnVariance", block_id));
auto* out_linear_bias_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "OutLinearBias", block_id));
auto* out_linear_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "OutLinearOut", block_id));
auto* out_linear_w_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "OutLinearW", block_id));
auto* qk_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "QKOut", block_id));
auto* qktv_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "QKTVOut", block_id));
auto* qkv_bias_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "QKVBias", block_id));
auto* qkv_bias_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "QKVBiasOut", block_id));
auto* qkv_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "QKVOut", block_id));
auto* qkv_w_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "QKVW", block_id));
auto* softmax_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "SoftmaxOut", block_id));
auto* src_mask_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "SrcMask", block_id));
auto* src_mask_out_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "SrcMaskOut", block_id));
auto* transpose_out_2_node = cache->GetNodeFromCache(
GenerateCacheKey(cache_anchor_name, "TransposeOut2", block_id));
fused_attention_grad_op_desc.SetInput("X", {x_node->Name()});
fused_attention_grad_op_desc.SetInput("AttnDropoutMaskOut",
{attn_dropout_mask_out_node->Name()});
fused_attention_grad_op_desc.SetInput("AttnDropoutOut",
{attn_dropout_out_node->Name()});
fused_attention_grad_op_desc.SetInput("DropoutMaskOut",
{dropout_mask_out_node->Name()});
fused_attention_grad_op_desc.SetInput("FMHAOut", {fmha_out_node->Name()});
fused_attention_grad_op_desc.SetInput("LnBias", {ln_bias_node->Name()});
fused_attention_grad_op_desc.SetInput("LnMean", {ln_mean_node->Name()});
fused_attention_grad_op_desc.SetInput("LnOut", {ln_out_node->Name()});
fused_attention_grad_op_desc.SetInput("LnScale", {ln_scale_node->Name()});
fused_attention_grad_op_desc.SetInput("LnVariance",
{ln_variance_node->Name()});
fused_attention_grad_op_desc.SetInput("OutLinearBias",
{out_linear_bias_node->Name()});
fused_attention_grad_op_desc.SetInput("OutLinearOut",
{out_linear_out_node->Name()});
fused_attention_grad_op_desc.SetInput("OutLinearW",
{out_linear_w_node->Name()});
fused_attention_grad_op_desc.SetInput("QKOut", {qk_out_node->Name()});
fused_attention_grad_op_desc.SetInput("QKTVOut", {qktv_out_node->Name()});
fused_attention_grad_op_desc.SetInput("QKVBias", {qkv_bias_node->Name()});
fused_attention_grad_op_desc.SetInput("QKVBiasOut",
{qkv_bias_out_node->Name()});
fused_attention_grad_op_desc.SetInput("QKVOut", {qkv_out_node->Name()});
fused_attention_grad_op_desc.SetInput("QKVW", {qkv_w_node->Name()});
fused_attention_grad_op_desc.SetInput("SoftmaxOut",
{softmax_out_node->Name()});
fused_attention_grad_op_desc.SetInput("SrcMask", {src_mask_node->Name()});
fused_attention_grad_op_desc.SetInput("SrcMaskOut",
{src_mask_out_node->Name()});
fused_attention_grad_op_desc.SetInput("TransposeOut2",
{transpose_out_2_node->Name()});
fused_attention_grad_op_desc.SetAttr("add_residual", true);
fused_attention_grad_op_desc.SetAttr(
"attn_dropout_rate",
PADDLE_GET_CONST(
float, attn_dropout_grad_op_node->Op()->GetAttr("dropout_prob")));
fused_attention_grad_op_desc.SetAttr(
"is_test",
PADDLE_GET_CONST(bool,
attn_dropout_grad_op_node->Op()->GetAttr("is_test")));
fused_attention_grad_op_desc.SetAttr(
"attn_dropout_fix_seed",
PADDLE_GET_CONST(bool,
attn_dropout_grad_op_node->Op()->GetAttr("fix_seed")));
fused_attention_grad_op_desc.SetAttr(
"attn_dropout_seed",
PADDLE_GET_CONST(int,
attn_dropout_grad_op_node->Op()->GetAttr("seed")));
fused_attention_grad_op_desc.SetAttr(
"attn_dropout_implementation",
PADDLE_GET_CONST(std::string,
attn_dropout_grad_op_node->Op()->GetAttr(
"dropout_implementation")));
fused_attention_grad_op_desc.SetAttr(
"dropout_rate",
PADDLE_GET_CONST(
float,
out_linear_dropout_grad_op_node->Op()->GetAttr("dropout_prob")));
fused_attention_grad_op_desc.SetAttr(
"dropout_fix_seed",
PADDLE_GET_CONST(
bool, out_linear_dropout_grad_op_node->Op()->GetAttr("fix_seed")));
fused_attention_grad_op_desc.SetAttr(
"dropout_seed",
PADDLE_GET_CONST(
int, out_linear_dropout_grad_op_node->Op()->GetAttr("seed")));
fused_attention_grad_op_desc.SetAttr(
"dropout_implementation",
PADDLE_GET_CONST(std::string,
out_linear_dropout_grad_op_node->Op()->GetAttr(
"dropout_implementation")));
fused_attention_grad_op_desc.SetAttr(
"epsilon",
PADDLE_GET_CONST(
float, pre_layer_norm_grad_op_node->Op()->GetAttr("epsilon")));
std::vector<int> shape =
PADDLE_GET_CONST(std::vector<int>,
fuse_qkv_reshape_grad_op_node->Op()->GetAttr("shape"));
fused_attention_grad_op_desc.SetAttr("num_heads", shape[2]);
fused_attention_grad_op_desc.SetAttr("pre_layer_norm", true);
fused_attention_grad_op_desc.SetAttr("transpose_qkv_wb", true);
// forward op will use default value
// but backward op has to set these redundant attrs
fused_attention_grad_op_desc.SetAttr(
"ln_epsilon",
PADDLE_GET_CONST(
float, pre_layer_norm_grad_op_node->Op()->GetAttr("epsilon")));
fused_attention_grad_op_desc.SetAttr("ring_id", -1);
GET_IR_NODE_FROM_SUBGRAPH(qkv_matmul_grad_x_grad_node,
qkv_matmul_grad_x_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_grad_x_grad_node,
out_linear_matmul_grad_x_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_grad_bias_grad_node,
pre_layer_norm_grad_bias_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_matmul_grad_x_grad_node,
fuse_qkv_matmul_grad_x_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_grad_scale_grad_node,
pre_layer_norm_grad_scale_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_grad_bias_grad_node,
out_linear_ele_add_grad_bias_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_grad_x_grad_node,
out_linear_ele_add_grad_x_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_grad_w_grad_node,
out_linear_matmul_grad_w_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(qk_scale_grad_out_node,
qk_scale_grad_out,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(qkv_transpose_grad_out_node,
qkv_transpose_grad_out,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_grad_bias_grad_node,
fuse_qkv_ele_add_grad_bias_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_reshape_grad_out_node,
fuse_qkv_reshape_grad_out,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_grad_x_grad_node,
fuse_qkv_ele_add_grad_x_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_matmul_grad_w_grad_node,
fuse_qkv_matmul_grad_w_grad,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(attn_dropout_grad_out_node,
attn_dropout_grad_out,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(qk_softmax_grad_out_node,
qk_softmax_grad_out,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_split_grad_out_node,
fuse_qkv_split_grad_out,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(grad_accumulation_out_node,
grad_accumulation_out,
fused_attention_grad_pattern);
fused_attention_grad_op_desc.SetOutput(
"AttnDropoutOut@GRAD", {qkv_matmul_grad_x_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"FMHAOut@GRAD", {out_linear_matmul_grad_x_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"LnBias@GRAD", {pre_layer_norm_grad_bias_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"LnOut@GRAD", {fuse_qkv_matmul_grad_x_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"LnScale@GRAD", {pre_layer_norm_grad_scale_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"OutLinearBias@GRAD", {out_linear_ele_add_grad_bias_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"OutLinearOut@GRAD", {out_linear_ele_add_grad_x_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"OutLinearW@GRAD", {out_linear_matmul_grad_w_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput("QKOut@GRAD",
{qk_scale_grad_out_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"QKTVOut@GRAD", {qkv_transpose_grad_out_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"QKVBias@GRAD", {fuse_qkv_ele_add_grad_bias_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"QKVBiasOut@GRAD", {fuse_qkv_reshape_grad_out_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"QKVOut@GRAD", {fuse_qkv_ele_add_grad_x_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"QKVW@GRAD", {fuse_qkv_matmul_grad_w_grad_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"SoftmaxOut@GRAD", {attn_dropout_grad_out_node->Name()});
fused_attention_grad_op_desc.SetOutput("SrcMaskOut@GRAD",
{qk_softmax_grad_out_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"TransposeOut2@GRAD", {fuse_qkv_split_grad_out_node->Name()});
fused_attention_grad_op_desc.SetOutput(
"X@GRAD", {grad_accumulation_out_node->Name()});
auto fused_attention_grad_node =
g->CreateOpNode(&fused_attention_grad_op_desc);
IR_NODE_LINK_TO(fused_attention_grad_node, qkv_matmul_grad_x_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
out_linear_matmul_grad_x_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
pre_layer_norm_grad_bias_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
fuse_qkv_matmul_grad_x_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
pre_layer_norm_grad_scale_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
out_linear_ele_add_grad_bias_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
out_linear_ele_add_grad_x_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
out_linear_matmul_grad_w_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node, qk_scale_grad_out_node);
IR_NODE_LINK_TO(fused_attention_grad_node, qkv_transpose_grad_out_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
fuse_qkv_ele_add_grad_bias_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node, fuse_qkv_reshape_grad_out_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
fuse_qkv_ele_add_grad_x_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node,
fuse_qkv_matmul_grad_w_grad_node);
IR_NODE_LINK_TO(fused_attention_grad_node, attn_dropout_grad_out_node);
IR_NODE_LINK_TO(fused_attention_grad_node, qk_softmax_grad_out_node);
IR_NODE_LINK_TO(fused_attention_grad_node, fuse_qkv_split_grad_out_node);
IR_NODE_LINK_TO(fused_attention_grad_node, grad_accumulation_out_node);
IR_NODE_LINK_TO(subgraph.at(x), fused_attention_grad_node);
IR_NODE_LINK_TO(x_node, fused_attention_grad_node);
IR_NODE_LINK_TO(attn_dropout_mask_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(attn_dropout_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(dropout_mask_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(fmha_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(ln_bias_node, fused_attention_grad_node);
IR_NODE_LINK_TO(ln_mean_node, fused_attention_grad_node);
IR_NODE_LINK_TO(ln_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(ln_scale_node, fused_attention_grad_node);
IR_NODE_LINK_TO(ln_variance_node, fused_attention_grad_node);
IR_NODE_LINK_TO(out_linear_bias_node, fused_attention_grad_node);
IR_NODE_LINK_TO(out_linear_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(out_linear_w_node, fused_attention_grad_node);
IR_NODE_LINK_TO(qk_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(qktv_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(qkv_bias_node, fused_attention_grad_node);
IR_NODE_LINK_TO(qkv_bias_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(qkv_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(qkv_w_node, fused_attention_grad_node);
IR_NODE_LINK_TO(softmax_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(src_mask_node, fused_attention_grad_node);
IR_NODE_LINK_TO(src_mask_out_node, fused_attention_grad_node);
IR_NODE_LINK_TO(transpose_out_2_node, fused_attention_grad_node);
GraphSafeRemoveNodes(g,
{residual_ele_add_grad_op_node,
......
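
The forward and backward rewrites above communicate only through the new cache: the forward handler stores every node the grad op will later need under a key derived from the matched QKV weight variable's name plus the block id, and the backward handler, which matches the same weight variable in the grad pattern, rebuilds those keys to pull the nodes back out and wire them into fused_attention_grad. The QKV weight works as the anchor because the same weight variable feeds both the forward matmul and its grad op, so both patterns can see its name. Below is a minimal standalone sketch of that hand-off, using plain STL containers and hypothetical variable names rather than Paddle's ir::Node and pattern-detector machinery; it is an illustration of the mechanism, not the pass itself.

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Stand-in for ir::Node*: only an identity is needed for the sketch.
struct FakeNode { std::string var_name; };

// Same key scheme as GenerateCacheKey: anchor + "_" + block_id + "_" + slot.
std::string MakeKey(const std::string& anchor, const std::string& slot, int block_id) {
  return anchor + "_" + std::to_string(block_id) + "_" + slot;
}

int main() {
  std::unordered_map<std::string, FakeNode*> cache;  // FusedAttentionPassCache analogue

  // "Forward" rewrite: the matched QKV weight name acts as the anchor.
  FakeNode ln_out{"layer_norm_0.tmp_2"};      // hypothetical variable name
  FakeNode softmax_out{"softmax_0.tmp_0"};    // hypothetical variable name
  const std::string anchor = "attn_qkv_weight_0";  // hypothetical QKV weight name
  const int block_id = 0;
  cache.emplace(MakeKey(anchor, "LnOut", block_id), &ln_out);
  cache.emplace(MakeKey(anchor, "SoftmaxOut", block_id), &softmax_out);

  // "Backward" rewrite: the grad pattern matches the same weight variable,
  // so it can rebuild the keys and fetch the forward nodes it needs.
  auto lookup = [&](const std::string& slot) -> FakeNode* {
    auto it = cache.find(MakeKey(anchor, slot, block_id));
    if (it == cache.end()) throw std::runtime_error("missing cache entry: " + slot);
    return it->second;
  };
  std::cout << "LnOut for grad op: " << lookup("LnOut")->var_name << "\n";
  std::cout << "SoftmaxOut for grad op: " << lookup("SoftmaxOut")->var_name << "\n";
  return 0;
}
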
@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -252,6 +253,31 @@ struct FusedAttentionGradPattern : public PatternBase { ...@@ -252,6 +253,31 @@ struct FusedAttentionGradPattern : public PatternBase {
} // namespace patterns } // namespace patterns
class FusedAttentionPassCache {
public:
ir::Node* GetNodeFromCache(const std::string name) {
if (var_name_to_ir_node_cache_.count(name)) {
return var_name_to_ir_node_cache_.find(name)->second;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"The key (%d) of FusedAttentionCache does not exist.", name));
}
void InsertIntoCache(const std::string name, ir::Node* node) {
if (!var_name_to_ir_node_cache_.count(name)) {
var_name_to_ir_node_cache_.insert({name, node});
} else {
PADDLE_THROW(platform::errors::AlreadyExists(
"The key (%d) of FusedAttentionCache already exist.", name));
}
}
void ResetCache() { var_name_to_ir_node_cache_.clear(); }
private:
std::unordered_map<std::string, ir::Node*> var_name_to_ir_node_cache_;
};
class FusedAttentionsPass : public FusePassBase {
public:
virtual ~FusedAttentionsPass() {}
@@ -273,9 +299,17 @@ class FusedAttentionsPass : public FusePassBase {
// If true, the function name will have an abbreviation part.
// If false, the function name won't contain an abbreviation for it.
- ir::Graph* PreMaskDropResFwd(Graph* graph) const;
ir::Graph* PreMaskDropResFwd(Graph* graph,
FusedAttentionPassCache* cache) const;
ir::Graph* PreMaskDropResBwd(Graph* graph,
FusedAttentionPassCache* cache) const;
- ir::Graph* PreMaskDropResBwd(Graph* graph) const;
const std::string GenerateCacheKey(const std::string anchor,
const std::string var_name,
int block_id) const {
return anchor + "_" + std::to_string(block_id) + "_" + var_name;
}
};
} // namespace ir
......
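
The cache declared in this header is deliberately fail-fast: InsertIntoCache throws if a key is already present and GetNodeFromCache throws if it is missing, so any mismatch between what the forward rewrite stored and what the backward rewrite requests surfaces immediately instead of producing a silently mis-wired graph. A small standalone mock of that contract follows (plain STL, illustrative only, not the Paddle class):

#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Illustrative stand-in for FusedAttentionPassCache; stores opaque pointers by key.
class FailFastCache {
 public:
  void Insert(const std::string& key, void* node) {
    // Refuse to overwrite: a duplicate key means two patterns claimed the same slot.
    if (!map_.emplace(key, node).second)
      throw std::invalid_argument("key already exists: " + key);
  }
  void* Get(const std::string& key) const {
    // Refuse to return a default: a missing key means the forward pass never cached it.
    auto it = map_.find(key);
    if (it == map_.end()) throw std::invalid_argument("key does not exist: " + key);
    return it->second;
  }
  void Reset() { map_.clear(); }

 private:
  std::unordered_map<std::string, void*> map_;
};

int main() {
  FailFastCache cache;
  int dummy = 0;
  cache.Insert("qkv_w_0_LnOut", &dummy);          // hypothetical key
  assert(cache.Get("qkv_w_0_LnOut") == &dummy);
  bool threw = false;
  try { cache.Get("qkv_w_0_QKOut"); } catch (const std::invalid_argument&) { threw = true; }
  assert(threw);                                   // missing keys fail loudly
  cache.Reset();
  return 0;
}
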
@@ -375,7 +375,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("BiasDropoutResidualOut",
"Result of residual + dropout(src + bias).")
.AsIntermediate();
- AddOutput("CacheKVOut", "The udpated cache KV.");
AddOutput("CacheKVOut", "The udpated cache KV.").AsDispensable();
AddOutput("Y", "Result after attention.");
AddAttr<int>("num_heads", "The number head for multi_head_attention.")
......
@@ -157,11 +157,20 @@ class TestFusedAttentionPass(unittest.TestCase):
assert ops[2].type == 'fused_attention'
assert ops[3].type == 'reduce_mean'
assert ops[5].type == 'reduce_mean_grad'
assert ops[6].type == 'fused_attention_grad'
# two ops for linear, one op for reduce mean
# one fill constant
# one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer
- assert ops[8].type == 'sgd'
assert ops[9].type == 'sgd'
exe = paddle.static.Executor()
exe.run(startup_prog)
rst = exe.run(
main_prog,
feed={'x': x_data, 'attn_mask': mask_data},
fetch_list=[loss],
)
if __name__ == "__main__":
......