未验证 提交 7e8ef328 编写于 作者: Y Yuang Liu 提交者: GitHub

Fused attention pass backward op replace. (#50186)

上级 f2ec69b4
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -252,6 +253,31 @@ struct FusedAttentionGradPattern : public PatternBase { ...@@ -252,6 +253,31 @@ struct FusedAttentionGradPattern : public PatternBase {
} // namespace patterns } // namespace patterns
class FusedAttentionPassCache {
public:
ir::Node* GetNodeFromCache(const std::string name) {
if (var_name_to_ir_node_cache_.count(name)) {
return var_name_to_ir_node_cache_.find(name)->second;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"The key (%d) of FusedAttentionCache does not exist.", name));
}
void InsertIntoCache(const std::string name, ir::Node* node) {
if (!var_name_to_ir_node_cache_.count(name)) {
var_name_to_ir_node_cache_.insert({name, node});
} else {
PADDLE_THROW(platform::errors::AlreadyExists(
"The key (%d) of FusedAttentionCache already exist.", name));
}
}
void ResetCache() { var_name_to_ir_node_cache_.clear(); }
private:
std::unordered_map<std::string, ir::Node*> var_name_to_ir_node_cache_;
};
class FusedAttentionsPass : public FusePassBase { class FusedAttentionsPass : public FusePassBase {
public: public:
virtual ~FusedAttentionsPass() {} virtual ~FusedAttentionsPass() {}
...@@ -273,9 +299,17 @@ class FusedAttentionsPass : public FusePassBase { ...@@ -273,9 +299,17 @@ class FusedAttentionsPass : public FusePassBase {
// If true, the function name will have an abbreviation part. // If true, the function name will have an abbreviation part.
// If false, the function name won't contain an abbreviation for it. // If false, the function name won't contain an abbreviation for it.
ir::Graph* PreMaskDropResFwd(Graph* graph) const; ir::Graph* PreMaskDropResFwd(Graph* graph,
FusedAttentionPassCache* cache) const;
ir::Graph* PreMaskDropResBwd(Graph* graph,
FusedAttentionPassCache* cache) const;
ir::Graph* PreMaskDropResBwd(Graph* graph) const; const std::string GenerateCacheKey(const std::string anchor,
const std::string var_name,
int block_id) const {
return anchor + "_" + std::to_string(block_id) + "_" + var_name;
}
}; };
} // namespace ir } // namespace ir
......
...@@ -375,7 +375,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -375,7 +375,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("BiasDropoutResidualOut", AddOutput("BiasDropoutResidualOut",
"Result of residual + dropout(src + bias).") "Result of residual + dropout(src + bias).")
.AsIntermediate(); .AsIntermediate();
AddOutput("CacheKVOut", "The udpated cache KV."); AddOutput("CacheKVOut", "The udpated cache KV.").AsDispensable();
AddOutput("Y", "Result after attention."); AddOutput("Y", "Result after attention.");
AddAttr<int>("num_heads", "The number head for multi_head_attention.") AddAttr<int>("num_heads", "The number head for multi_head_attention.")
......
...@@ -157,11 +157,20 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -157,11 +157,20 @@ class TestFusedAttentionPass(unittest.TestCase):
assert ops[2].type == 'fused_attention' assert ops[2].type == 'fused_attention'
assert ops[3].type == 'reduce_mean' assert ops[3].type == 'reduce_mean'
assert ops[5].type == 'reduce_mean_grad' assert ops[5].type == 'reduce_mean_grad'
assert ops[6].type == 'fused_attention_grad'
# two ops for linear, one op for reduce mean # two ops for linear, one op for reduce mean
# one fill constant # one fill constant
# one op for reduce mean grad, two ops for linear bwd # one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer # the eighth op should be the optimizer
assert ops[8].type == 'sgd' assert ops[9].type == 'sgd'
exe = paddle.static.Executor()
exe.run(startup_prog)
rst = exe.run(
main_prog,
feed={'x': x_data, 'attn_mask': mask_data},
fetch_list=[loss],
)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册