diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc index b9bd660043bf1b0d24cf302bf782ec179245ff6a..1e9598fff87a8e9504db4f60f08b9fd4160e4a58 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -30,6 +30,44 @@ namespace ir { GET_IR_NODE(reshape2_op); \ GET_IR_NODE(reshape2_out); +ShuffleChannelDetectPass::ShuffleChannelDetectPass() { + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsOptional() + .IsTensor() + .End() + .AddInput("ShapeTensor") + .IsOptional() + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("shape") + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); +} + void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { const std::string pattern_name = "shufflechannel_pattern"; FusePassBase::Init(pattern_name, graph); @@ -46,7 +84,10 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { GET_NODES; - + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "The Pass in op compat failed."; + return; + } PADDLE_ENFORCE_GT( subgraph.count(x), 0, platform::errors::NotFound("Detector did not find input X.")); diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h index d0caba5629f00384694c7aa289db734d4ab74253..4576cfd865bb3392ea01ff22bb521c7a2005c275 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h @@ -26,6 +26,7 @@ class Graph; class ShuffleChannelDetectPass : public FusePassBase { public: + ShuffleChannelDetectPass(); virtual ~ShuffleChannelDetectPass() {} protected: