From 0ee230a7d3177f791d2a5388ab4dffdccc03f4aa Mon Sep 17 00:00:00 2001 From: wenbin Date: Mon, 21 Feb 2022 16:14:17 +0800 Subject: [PATCH] shuffle_channel pass fix (#39735) --- .../ir/shuffle_channel_detect_pass.cc | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc index 1e9598fff87..bcd7bedcc43 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -94,6 +94,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { auto* input_node = subgraph.at(x); auto reshape1_desc = reshape1_op->Op(); auto reshape2_desc = reshape2_op->Op(); + auto trans_desc = transpose_op->Op(); std::string input_name = input_node->Name(); std::string output_name = reshape2_out->Name(); @@ -101,10 +102,101 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { BOOST_GET_CONST(std::vector, reshape1_desc->GetAttr("shape")); auto reshape2_shape = BOOST_GET_CONST(std::vector, reshape2_desc->GetAttr("shape")); + auto trans_axis = + BOOST_GET_CONST(std::vector, trans_desc->GetAttr("axis")); + auto* block1 = reshape1_desc->Block(); + auto* block2 = reshape2_desc->Block(); + if (block1 && block2) { + auto x_var_name = reshape1_desc->Input("X")[0]; + auto* x_var_desc = block1->FindVar(x_var_name); + auto x_shape1 = x_var_desc->GetShape(); + x_var_name = reshape2_desc->Input("X")[0]; + x_var_desc = block2->FindVar(x_var_name); + auto x_shape2 = x_var_desc->GetShape(); + // now shuffle_channel is 4D(NCHW) only. + if (x_shape1.size() != 4 || reshape1_shape.size() != 5 || + reshape2_shape.size() != 4 || trans_axis.size() != 5) { + return; + } + + // process 0 and -1 in reshape. + constexpr int64_t copy_dim_val = 0; + for (size_t i = 0; i < reshape1_shape.size(); i++) { + if (reshape1_shape[i] == copy_dim_val) { + reshape1_shape[i] = x_shape1[i]; + } + } + for (size_t i = 0; i < reshape2_shape.size(); i++) { + if (reshape2_shape[i] == copy_dim_val) { + reshape2_shape[i] = x_shape2[i]; + } + } + constexpr int64_t unk_dim_idx = -1; + bool all_positive = std::all_of(x_shape1.cbegin(), x_shape1.cend(), + [](int64_t i) { return i > 0; }); + for (size_t i = 0; i < reshape1_shape.size(); ++i) { + // if -1 is not in batch dim, try to calculate number + if ((reshape1_shape[i] == unk_dim_idx) && (i != 0)) { + // there is no sufficient info + if (!all_positive) return; + reshape1_shape[i] = + std::accumulate(x_shape1.begin(), x_shape1.end(), + static_cast(1), + std::multiplies()) / + std::accumulate(reshape1_shape.begin(), reshape1_shape.end(), + static_cast(-1), + std::multiplies()); + break; + } + } + + all_positive = std::all_of(x_shape2.cbegin(), x_shape2.cend(), + [](int64_t i) { return i > 0; }); + for (size_t i = 0; i < reshape2_shape.size(); ++i) { + // if -1 is not in batch dim, try to calculate number + if ((reshape2_shape[i] == unk_dim_idx) && (i != 0)) { + // there is no sufficient info + if (!all_positive) return; + reshape2_shape[i] = + std::accumulate(x_shape2.begin(), x_shape2.end(), + static_cast(1), + std::multiplies()) / + std::accumulate(reshape2_shape.begin(), reshape2_shape.end(), + static_cast(-1), + std::multiplies()); + break; + } + } + + // shuffle_channel dosen't change shape + if ((reshape2_shape[0] != -1) && (x_shape1[0] != reshape2_shape[0])) { + return; + } + for (size_t i = 1; i < x_shape1.size(); i++) { + if (x_shape1[i] != reshape2_shape[i]) { + return; + } + } + if ((reshape2_shape[3] != reshape1_shape[4]) || + (reshape2_shape[2] != reshape1_shape[3])) { + return; + } + } else { + return; // conservative judgement + } int i_c = reshape1_shape[2]; int o_c = reshape2_shape[1]; int group = o_c / i_c; + // should split on channel dim + if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return; + // trans on channel dim + if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return; + if (group != 1 && i_c != 1) { + if (trans_axis[1] != 2 && trans_axis[2] != 1) { + return; + } + } framework::OpDesc new_op_desc; new_op_desc.SetType("shuffle_channel"); -- GitLab