From e7f5d32575ee8c40344c62a921d75e46db0b1575 Mon Sep 17 00:00:00 2001 From: wenbin Date: Mon, 13 Dec 2021 19:34:39 +0800 Subject: [PATCH] disable bad case for shuffle pass (#38072) * disabled bad case * int to size_t --- .../framework/ir/shuffle_channel_detect_pass.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc index 1e9598fff87..02e74b7f837 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -101,6 +101,21 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { BOOST_GET_CONST(std::vector, reshape1_desc->GetAttr("shape")); auto reshape2_shape = BOOST_GET_CONST(std::vector, reshape2_desc->GetAttr("shape")); + // shuffle_channel dosen't change shape + auto* block = reshape1_desc->Block(); + if (block) { + auto x_var_name = reshape1_desc->Input("X")[0]; + auto* x_var_desc = block->FindVar(x_var_name); + const auto x_shape = x_var_desc->GetShape(); + + if (x_shape.size() != reshape2_shape.size()) { + return; + } + + for (size_t i = 0; i < x_shape.size(); i++) { + if (x_shape[i] != reshape2_shape[i]) return; + } + } int i_c = reshape1_shape[2]; int o_c = reshape2_shape[1]; -- GitLab