diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc index 50d6b97bbea8ef5508f8bfaa8f84717cecb375f4..523c2161326466eac21e89d9b5442c16138e967a 100644 --- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc @@ -19,7 +19,50 @@ namespace paddle { namespace framework { namespace ir { -void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) { +TransposeFlattenConcatFusePass::TransposeFlattenConcatFusePass() { + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); + AddOpCompat(OpCompat("flatten2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumGE(0) + .End(); + AddOpCompat(OpCompat("concat")) + .AddInput("X") // Input("X"): vector + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, 1}) + .End(); +} + +void TransposeFlattenConcatFusePass::RunTransposeFlattenConcatFuse( + ir::Graph *graph, int times) const { const std::string pattern_name = "transpose_flatten" + std::to_string(times) + "_concat_fuse"; @@ -37,6 +80,10 @@ void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) { auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } const int kNumFields = 5; const int kTransOffset = 1; const int kTransOutOffset = 2; diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h index 939a8c31e5501e23968f9b44b4fe09e78280fd07..7c3ef2986e27e0656b3722bc5cb1c77d98190d62 100644 --- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h @@ -16,7 +16,6 @@ #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { @@ -28,10 +27,14 @@ namespace ir { // structure. class TransposeFlattenConcatFusePass : public FusePassBase { public: + TransposeFlattenConcatFusePass(); virtual ~TransposeFlattenConcatFusePass() {} protected: void ApplyImpl(ir::Graph* graph) const override; + + private: + void RunTransposeFlattenConcatFuse(ir::Graph* graph, int times) const; }; } // namespace ir