未验证 提交 0f59d4e6 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for multihead_matmul_fuse_pass_v2,v3, test=develop (#33786)

上级 7f9b8f06
......@@ -18,16 +18,6 @@
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
......@@ -158,22 +148,30 @@ class MultiHeadMatmulFusePass : public FusePassBase {
class MultiHeadMatmulV2FusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulV2FusePass() {}
MultiHeadMatmulV2FusePass();
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_fuse_v2"};
private:
int BuildFusionV2(Graph* graph, const std::string& name_scope,
Scope* scope) const;
};
class MultiHeadMatmulV3FusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulV3FusePass() {}
MultiHeadMatmulV3FusePass();
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_fuse_v3"};
private:
int BuildFusionV3(Graph* graph, const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
......
......@@ -64,7 +64,7 @@ TEST(MultiHeadMatmulFusePass, basic) {
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) mul -> mul_qkv
Layers layers;
auto* x = layers.data("x", {128, 768});
auto* x = layers.data("x", {1, 128, 768});
auto out = layers.layer_norm(x);
auto* layer_out = out[0];
......@@ -72,41 +72,41 @@ TEST(MultiHeadMatmulFusePass, basic) {
auto* weights_1 = layers.data("weights1", {768, 768}, true);
auto* weights_2 = layers.data("weights2", {768, 768}, true);
auto* mul_out_0 = layers.mul(layer_out, weights_0);
auto* mul_out_1 = layers.mul(layer_out, weights_1);
auto* mul_out_2 = layers.mul(layer_out, weights_2);
auto* mul_out_0 = layers.mul(layer_out, weights_0, nullptr, 2);
auto* mul_out_1 = layers.mul(layer_out, weights_1, nullptr, 2);
auto* mul_out_2 = layers.mul(layer_out, weights_2, nullptr, 2);
auto* b0 = layers.data("bias_0", {768}, true);
auto* b1 = layers.data("bias_1", {768}, true);
auto* b2 = layers.data("bias_2", {768}, true);
auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0);
auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1);
auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2);
auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2, nullptr, 2);
std::vector<int> shape = {128, 12, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape);
std::vector<int> shape = {1, 128, 12, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis);
auto* transpose_1 = layers.transpose2(reshape_1, axis);
auto* transpose_2 = layers.transpose2(reshape_2, axis);
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false);
auto* matmul_qk = layers.matmul(scale_0, transpose_1);
auto* matmul_qk = layers.matmul(scale_0, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {768}, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3});
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {128, 768});
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 768}, true);
auto* weights_l = layers.data("weightsl", {768, 768}, true);
layers.mul(reshape_qkv_out, weights_l);
layers.mul(reshape_qkv_out, weights_l, nullptr, 2);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
......
......@@ -293,13 +293,17 @@ struct Layers {
return outs;
}
VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr) {
VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
bool transpose_x = false, bool transpose_y = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("matmul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("transpose_X", transpose_x);
op->SetAttr("transpose_Y", transpose_y);
op->SetAttr("alpha", 1.0f);
return out;
}
......
......@@ -23,6 +23,10 @@ def {
}
}
extra {
attrs {
name: "head_number"
type: INT
}
attrs {
name: "Scale_out"
type: FLOAT
......
......@@ -10,12 +10,12 @@ def {
name: "axis"
type: INT
}
}
extra {
attrs {
name: "data_format"
type: STRING
}
}
extra {
attrs {
name: "op_role"
type: INT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册