From 54af52b0344dadd0956e746779727a80534585b8 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Thu, 1 Jul 2021 10:31:34 +0800 Subject: [PATCH] Shuffle channel detect pass (#33814) --- .../ir/shuffle_channel_detect_pass.cc | 43 ++++++++++++++++++- .../ir/shuffle_channel_detect_pass.h | 1 + 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc index b9bd660043..1e9598fff8 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -30,6 +30,44 @@ namespace ir { GET_IR_NODE(reshape2_op); \ GET_IR_NODE(reshape2_out); +ShuffleChannelDetectPass::ShuffleChannelDetectPass() { + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsOptional() + .IsTensor() + .End() + .AddInput("ShapeTensor") + .IsOptional() + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("shape") + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); +} + void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { const std::string pattern_name = "shufflechannel_pattern"; FusePassBase::Init(pattern_name, graph); @@ -46,7 +84,10 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { GET_NODES; - + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "The Pass in op compat failed."; + return; + } PADDLE_ENFORCE_GT( subgraph.count(x), 0, platform::errors::NotFound("Detector did not find input X.")); diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h index d0caba5629..4576cfd865 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h @@ -26,6 +26,7 @@ class Graph; class ShuffleChannelDetectPass : public FusePassBase { public: + ShuffleChannelDetectPass(); virtual ~ShuffleChannelDetectPass() {} protected: -- GitLab