未验证 提交 0f59d4e6 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for multihead_matmul_fuse_pass_v2,v3, test=develop (#33786)

上级 7f9b8f06
...@@ -18,16 +18,6 @@ ...@@ -18,16 +18,6 @@
#include <string> #include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -158,22 +148,30 @@ class MultiHeadMatmulFusePass : public FusePassBase { ...@@ -158,22 +148,30 @@ class MultiHeadMatmulFusePass : public FusePassBase {
class MultiHeadMatmulV2FusePass : public FusePassBase { class MultiHeadMatmulV2FusePass : public FusePassBase {
public: public:
virtual ~MultiHeadMatmulV2FusePass() {} MultiHeadMatmulV2FusePass();
protected: protected:
void ApplyImpl(Graph* graph) const; void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_fuse_v2"}; const std::string name_scope_{"multihead_matmul_fuse_v2"};
private:
int BuildFusionV2(Graph* graph, const std::string& name_scope,
Scope* scope) const;
}; };
class MultiHeadMatmulV3FusePass : public FusePassBase { class MultiHeadMatmulV3FusePass : public FusePassBase {
public: public:
virtual ~MultiHeadMatmulV3FusePass() {} MultiHeadMatmulV3FusePass();
protected: protected:
void ApplyImpl(Graph* graph) const; void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_fuse_v3"}; const std::string name_scope_{"multihead_matmul_fuse_v3"};
private:
int BuildFusionV3(Graph* graph, const std::string& name_scope,
Scope* scope) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -64,7 +64,7 @@ TEST(MultiHeadMatmulFusePass, basic) { ...@@ -64,7 +64,7 @@ TEST(MultiHeadMatmulFusePass, basic) {
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) mul -> mul_qkv // (reshape_qkv) mul -> mul_qkv
Layers layers; Layers layers;
auto* x = layers.data("x", {128, 768}); auto* x = layers.data("x", {1, 128, 768});
auto out = layers.layer_norm(x); auto out = layers.layer_norm(x);
auto* layer_out = out[0]; auto* layer_out = out[0];
...@@ -72,41 +72,41 @@ TEST(MultiHeadMatmulFusePass, basic) { ...@@ -72,41 +72,41 @@ TEST(MultiHeadMatmulFusePass, basic) {
auto* weights_1 = layers.data("weights1", {768, 768}, true); auto* weights_1 = layers.data("weights1", {768, 768}, true);
auto* weights_2 = layers.data("weights2", {768, 768}, true); auto* weights_2 = layers.data("weights2", {768, 768}, true);
auto* mul_out_0 = layers.mul(layer_out, weights_0); auto* mul_out_0 = layers.mul(layer_out, weights_0, nullptr, 2);
auto* mul_out_1 = layers.mul(layer_out, weights_1); auto* mul_out_1 = layers.mul(layer_out, weights_1, nullptr, 2);
auto* mul_out_2 = layers.mul(layer_out, weights_2); auto* mul_out_2 = layers.mul(layer_out, weights_2, nullptr, 2);
auto* b0 = layers.data("bias_0", {768}, true); auto* b0 = layers.data("bias_0", {768}, true);
auto* b1 = layers.data("bias_1", {768}, true); auto* b1 = layers.data("bias_1", {768}, true);
auto* b2 = layers.data("bias_2", {768}, true); auto* b2 = layers.data("bias_2", {768}, true);
auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0); auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1); auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2); auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2, nullptr, 2);
std::vector<int> shape = {128, 12, 64}; std::vector<int> shape = {1, 128, 12, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape); auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape); auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape); auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
std::vector<int> axis = {0, 2, 1, 3}; std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis); auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis); auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis); auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false); auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false);
auto* matmul_qk = layers.matmul(scale_0, transpose_1); auto* matmul_qk = layers.matmul(scale_0, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {768}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2); auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {128, 768}); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 768}, true);
auto* weights_l = layers.data("weightsl", {768, 768}, true); auto* weights_l = layers.data("weightsl", {768, 768}, true);
layers.mul(reshape_qkv_out, weights_l); layers.mul(reshape_qkv_out, weights_l, nullptr, 2);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
......
...@@ -293,13 +293,17 @@ struct Layers { ...@@ -293,13 +293,17 @@ struct Layers {
return outs; return outs;
} }
VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr) { VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
bool transpose_x = false, bool transpose_y = false) {
VarDesc* out = lod_tensor(unique_name()); VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp(); OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("matmul"); op->SetType("matmul");
op->SetInput("X", {x->Name()}); op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()}); op->SetInput("Y", {y->Name()});
op->SetOutput("Out", {out->Name()}); op->SetOutput("Out", {out->Name()});
op->SetAttr("transpose_X", transpose_x);
op->SetAttr("transpose_Y", transpose_y);
op->SetAttr("alpha", 1.0f);
return out; return out;
} }
......
...@@ -23,6 +23,10 @@ def { ...@@ -23,6 +23,10 @@ def {
} }
} }
extra { extra {
attrs {
name: "head_number"
type: INT
}
attrs { attrs {
name: "Scale_out" name: "Scale_out"
type: FLOAT type: FLOAT
......
...@@ -10,12 +10,12 @@ def { ...@@ -10,12 +10,12 @@ def {
name: "axis" name: "axis"
type: INT type: INT
} }
}
extra {
attrs { attrs {
name: "data_format" name: "data_format"
type: STRING type: STRING
} }
}
extra {
attrs { attrs {
name: "op_role" name: "op_role"
type: INT type: INT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册