diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc index 157fd4d1a4e18fe83e7e74d9b6ddb5970d905d6c..583e45b5742f989b3430bb6a748da43790261c59 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc @@ -174,6 +174,91 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) { return fc_out; } +SeqConcatFcFusePass::SeqConcatFcFusePass() { + AddOpCompat(OpCompat("sequence_expand")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("ref_level") + .IsNumEQ(0) + .End(); + + AddOpCompat(OpCompat("concat")) + .AddInput("X") // Input("X"): vector + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("tanh")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("sigmoid")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init("seq_concat_fc_fuse", graph); GraphPatternDetector detector; @@ -193,6 +278,10 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const { detector(graph, [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "seq_concat_fc_fuse_pass in op compat failed."; + return; + } VLOG(4) << "get one concat pattern"; // fc GET_NODE(fc_w, detector.pattern()); diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h index a70411536455757b49292e990d27e372651b88c9..99dcd4455bc1e90a10fa07ef4e85ecb4ac83b6fb 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h @@ -15,8 +15,6 @@ #pragma once #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/pass.h" namespace paddle { namespace framework { @@ -26,6 +24,7 @@ class Graph; class SeqConcatFcFusePass : public FusePassBase { public: + SeqConcatFcFusePass(); virtual ~SeqConcatFcFusePass() {} protected: