未验证 提交 8c4c0725 编写于 作者: F feng_shuai 提交者: GitHub

Scale matmul fuse pass (#33803)

* scale_matmul_fuse_pass_init

* enhance scale_matmul_fuse_pass

* change scale_matmul_fuse_pass unittest
上级 f62fce01
......@@ -28,6 +28,45 @@ namespace ir {
class Graph;
using string::PrettyLogDetail;
ScaleMatmulFusePass::ScaleMatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGT(0.0f)
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsNumGT(0.0f)
.End()
.AddAttr("bias")
.IsNumEQ(0.0f)
.End()
.AddAttr("bias_after_scale")
.IsOptional()
.IsType<bool>()
.End();
}
void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
......@@ -43,6 +82,10 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_scale_matmul_fuse_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(scale_in, scale_in, scale_matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, scale_matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, scale_matmul_pattern);
......@@ -75,6 +118,11 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
matmul_op->Op()->SetInput(matmul_op_input_name,
std::vector<std::string>({scale_in->Name()}));
IR_NODE_LINK_TO(scale_in, matmul_op);
if (!IsCompat(*matmul_op->Op())) {
LOG(WARNING) << "scale_matmul_fuse_pass in out fc op compat failed.";
return;
}
GraphSafeRemoveNodes(graph, {scale_op, scale_out});
found_scale_matmul_fuse_count++;
}
......
......@@ -24,6 +24,7 @@ class Graph;
class ScaleMatmulFusePass : public FusePassBase {
public:
ScaleMatmulFusePass();
virtual ~ScaleMatmulFusePass() {}
protected:
......
......@@ -31,6 +31,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op->SetAttr("scale", scale);
op->SetAttr("bias", bias);
} else if (type == "matmul") {
op->SetAttr("transpose_X", false);
op->SetAttr("transpose_Y", false);
op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetAttr("alpha", scale);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册