From 72af57bb840b96ea52f5378f699cde711e73a152 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Mon, 5 Jul 2021 20:08:27 +0800 Subject: [PATCH] [pass_enhance] : seq_concat_fc_fuse_pass (#33961) --- .../framework/ir/seq_concat_fc_fuse_pass.cc | 89 +++++++++++++++++++ .../framework/ir/seq_concat_fc_fuse_pass.h | 3 +- 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc index 157fd4d1a4..583e45b574 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc @@ -174,6 +174,91 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) { return fc_out; } +SeqConcatFcFusePass::SeqConcatFcFusePass() { + AddOpCompat(OpCompat("sequence_expand")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("ref_level") + .IsNumEQ(0) + .End(); + + AddOpCompat(OpCompat("concat")) + .AddInput("X") // Input("X"): vector + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("tanh")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("sigmoid")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init("seq_concat_fc_fuse", graph); GraphPatternDetector detector; @@ -193,6 +278,10 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const { detector(graph, [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "seq_concat_fc_fuse_pass in op compat failed."; + return; + } VLOG(4) << "get one concat pattern"; // fc GET_NODE(fc_w, detector.pattern()); diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h index a704115364..99dcd4455b 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h @@ -15,8 +15,6 @@ #pragma once #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/pass.h" namespace paddle { namespace framework { @@ -26,6 +24,7 @@ class Graph; class SeqConcatFcFusePass : public FusePassBase { public: + SeqConcatFcFusePass(); virtual ~SeqConcatFcFusePass() {} protected: -- GitLab