diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc index 1e9598fff87a8e9504db4f60f08b9fd4160e4a58..02e74b7f837706d7e92339ccb7bd04f4006517e4 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -101,6 +101,21 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { BOOST_GET_CONST(std::vector, reshape1_desc->GetAttr("shape")); auto reshape2_shape = BOOST_GET_CONST(std::vector, reshape2_desc->GetAttr("shape")); + // shuffle_channel dosen't change shape + auto* block = reshape1_desc->Block(); + if (block) { + auto x_var_name = reshape1_desc->Input("X")[0]; + auto* x_var_desc = block->FindVar(x_var_name); + const auto x_shape = x_var_desc->GetShape(); + + if (x_shape.size() != reshape2_shape.size()) { + return; + } + + for (size_t i = 0; i < x_shape.size(); i++) { + if (x_shape[i] != reshape2_shape[i]) return; + } + } int i_c = reshape1_shape[2]; int o_c = reshape2_shape[1];