From af9ddeb789642f275ce5d3bccf1f3be94abde277 Mon Sep 17 00:00:00 2001
From: wenbin
Date: Thu, 27 Jan 2022 14:10:43 +0800
Subject: [PATCH] fix shuffle_channel_detect_pass (#39242)

* shuffle channel pass

* add ut

* timeout fix

* makefile fix
---
 .../ir/shuffle_channel_detect_pass.cc         |  96 ++++++++++++++--
 .../tensorrt/convert/shuffle_channel_op.cc    |   7 +-
 paddle/fluid/inference/tensorrt/op_teller.cc  |  15 ++-
 .../unittests/ir/inference/CMakeLists.txt     |   1 +
 .../test_shuffle_channel_detect_pass.py       | 107 ++++++++++++++++++
 5 files changed, 210 insertions(+), 16 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_shuffle_channel_detect_pass.py

diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
index 02e74b7f837..63cd4f1f8ef 100644
--- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
+++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
@@ -94,6 +94,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
     auto* input_node = subgraph.at(x);
     auto reshape1_desc = reshape1_op->Op();
     auto reshape2_desc = reshape2_op->Op();
+    auto trans_desc = transpose_op->Op();
     std::string input_name = input_node->Name();
     std::string output_name = reshape2_out->Name();
 
@@ -101,25 +102,102 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
         BOOST_GET_CONST(std::vector<int>, reshape1_desc->GetAttr("shape"));
     auto reshape2_shape =
         BOOST_GET_CONST(std::vector<int>, reshape2_desc->GetAttr("shape"));
-    // shuffle_channel dosen't change shape
-    auto* block = reshape1_desc->Block();
-    if (block) {
+    auto trans_axis =
+        BOOST_GET_CONST(std::vector<int>, trans_desc->GetAttr("axis"));
+    auto* block1 = reshape1_desc->Block();
+    auto* block2 = reshape2_desc->Block();
+    if (block1 && block2) {
       auto x_var_name = reshape1_desc->Input("X")[0];
-      auto* x_var_desc = block->FindVar(x_var_name);
-      const auto x_shape = x_var_desc->GetShape();
-
-      if (x_shape.size() != reshape2_shape.size()) {
+      auto* x_var_desc = block1->FindVar(x_var_name);
+      auto x_shape1 = x_var_desc->GetShape();
+      x_var_name = reshape2_desc->Input("X")[0];
+      x_var_desc = block2->FindVar(x_var_name);
+      auto x_shape2 = x_var_desc->GetShape();
+      // for now, shuffle_channel supports 4-D (NCHW) input only.
+      if (x_shape1.size() != 4 || reshape1_shape.size() != 5 ||
+          reshape2_shape.size() != 4 || trans_axis.size() != 5) {
        return;
      }
-      for (size_t i = 0; i < x_shape.size(); i++) {
-        if (x_shape[i] != reshape2_shape[i]) return;
+      // process 0 and -1 in reshape.
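+      // in Paddle's reshape, a "0" keeps the corresponding input dim and a
+      // "-1" dim is inferred from the remaining element count, so both must
+      // be resolved before the shapes can be compared.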
+      constexpr int64_t copy_dim_val = 0;
+      for (size_t i = 0; i < reshape1_shape.size(); i++) {
+        if (reshape1_shape[i] == copy_dim_val) {
+          reshape1_shape[i] = x_shape1[i];
+        }
+      }
+      for (size_t i = 0; i < reshape2_shape.size(); i++) {
+        if (reshape2_shape[i] == copy_dim_val) {
+          reshape2_shape[i] = x_shape2[i];
+        }
+      }
+      constexpr int64_t unk_dim_idx = -1;
+      bool all_positive = std::all_of(x_shape1.cbegin(), x_shape1.cend(),
+                                      [](int64_t i) { return i > 0; });
+      for (size_t i = 0; i < reshape1_shape.size(); ++i) {
+        // if -1 is not in the batch dim, try to calculate its concrete value
+        if ((reshape1_shape[i] == unk_dim_idx) && (i != 0)) {
+          // there is no sufficient info to infer the dim
+          if (!all_positive) return;
+          reshape1_shape[i] =
+              std::accumulate(x_shape1.begin(), x_shape1.end(),
+                              static_cast<int64_t>(1),
+                              std::multiplies<int64_t>()) /
+              std::accumulate(reshape1_shape.begin(), reshape1_shape.end(),
+                              static_cast<int64_t>(-1),
+                              std::multiplies<int64_t>());
+          break;
+        }
+      }
+
+      all_positive = std::all_of(x_shape2.cbegin(), x_shape2.cend(),
+                                 [](int64_t i) { return i > 0; });
+      for (size_t i = 0; i < reshape2_shape.size(); ++i) {
+        // if -1 is not in the batch dim, try to calculate its concrete value
+        if ((reshape2_shape[i] == unk_dim_idx) && (i != 0)) {
+          // there is no sufficient info to infer the dim
+          if (!all_positive) return;
+          reshape2_shape[i] =
+              std::accumulate(x_shape2.begin(), x_shape2.end(),
+                              static_cast<int64_t>(1),
+                              std::multiplies<int64_t>()) /
+              std::accumulate(reshape2_shape.begin(), reshape2_shape.end(),
+                              static_cast<int64_t>(-1),
+                              std::multiplies<int64_t>());
+          break;
+        }
+      }
+
+      // shuffle_channel doesn't change shape
+      if ((reshape2_shape[0] != -1) && (x_shape1[0] != reshape2_shape[0])) {
+        return;
       }
+      for (size_t i = 1; i < x_shape1.size(); i++) {
+        if (x_shape1[i] != reshape2_shape[i]) {
+          return;
+        }
+      }
+      if ((reshape2_shape[3] != reshape1_shape[4]) ||
+          (reshape2_shape[2] != reshape1_shape[3])) {
+        return;
+      }
+    } else {
+      return;  // conservative judgement
     }
 
     int i_c = reshape1_shape[2];
     int o_c = reshape2_shape[1];
     int group = o_c / i_c;
+    // the first reshape should split on the channel dim
+    if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return;
+    // the transpose should act on the channel dims
+    if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return;
+
+    if (group != 1) {
+      if (trans_axis[1] != 2 && trans_axis[2] != 1) {
+        return;
+      }
+    }
 
     framework::OpDesc new_op_desc;
     new_op_desc.SetType("shuffle_channel");
diff --git a/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc
index 976fe9502ac..e6422522e50 100644
--- a/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc
@@ -39,12 +39,7 @@ class ShuffleChannelOpConverter : public OpConverter {
     // Declare inputs
     auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
     auto input_dims = input->getDimensions();
-    PADDLE_ENFORCE_EQ(
-        input_dims.nbDims, 3,
-        platform::errors::InvalidArgument("ShuffleChannel TRT op converter "
-                                          "input dims is invalid. The input "
-                                          "dims size should be 3, but got %d.",
-                                          input_dims.nbDims));
+
     int c = input_dims.d[0];
     int h = input_dims.d[1];
     int w = input_dims.d[2];
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 558579c6253..5e320a02702 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -1295,6 +1295,20 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
                "the shuffle_channel op does not support dynamic shape yet";
         return false;
       }
+      auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
+      auto* input_desc = block->FindVar(desc.Input("X").front());
+      const auto input_shape = input_desc->GetShape();
+      if (input_shape.size() != 4) {
+        VLOG(3) << "input dims are invalid. The input "
+                   "dims size should be 4.";
+        return false;
+      }
     }
 
     if (op_type == "skip_layernorm") {
@@ -1606,7 +1620,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
   }
 
-  VLOG(3) << "trt unsupported op " << op_type;
 
   return false;
 }
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index 11abb2623bb..e3680104251 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -102,6 +102,7 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU)
   set_tests_properties(test_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
   set_tests_properties(test_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
   set_tests_properties(test_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
+  set_tests_properties(test_shuffle_channel_detect_pass PROPERTIES TIMEOUT 120)
   if (WIN32)
     set_tests_properties(test_matmul_scale_fuse_pass PROPERTIES TIMEOUT 300)
     set_tests_properties(test_matmul_v2_scale_fuse_pass PROPERTIES TIMEOUT 300)
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_shuffle_channel_detect_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_shuffle_channel_detect_pass.py
new file mode 100644
index 00000000000..a864e2fe5a1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_shuffle_channel_detect_pass.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
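+
+# The test below builds the reshape2 -> transpose2 -> reshape2 pattern that
+# shuffle_channel_detect_pass looks for, and checks that the pass rewrites
+# the subgraph into a single shuffle_channel op under a TensorRT config.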
+
+from auto_scan_test import PassAutoScanTest, IgnoreReasons
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+import unittest
+
+import hypothesis
+from hypothesis import given, settings, seed, example, assume, reproduce_failure
+import hypothesis.strategies as st
+
+
+class TestShuffleChannelDetectPass(PassAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        attrs = [
+            program_config.ops[i].attrs
+            for i in range(len(program_config.ops))
+        ]
+
+        # the pattern only matches when the second reshape2 restores the
+        # shape of the first reshape2's input
+        if attrs[0]['input_shape'] != attrs[2]['shape']:
+            return False
+
+        return True
+
+    def sample_program_config(self, draw):
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+        out_channel = draw(st.integers(min_value=1, max_value=16))
+        group = draw(st.integers(min_value=1, max_value=4))
+        in_channel = group * out_channel
+        x_shape = [batch_size, in_channel, 64, 64]
+        shape = [0, group, out_channel, -1, 64]
+        axis_v = [0, 2, 1, 3, 4]
+
+        def generate_reshape2_input():
+            return np.random.random(x_shape).astype(np.float32)
+
+        reshape2_op1 = OpConfig(
+            "reshape2",
+            inputs={"X": ["reshape2_input1"], },
+            outputs={
+                "Out": ["reshape2_output1"],
+                "XShape": ["reshape2_xshape1"]
+            },
+            shape=shape,
+            input_shape=x_shape)
+        transpose2_op = OpConfig(
+            "transpose2",
+            inputs={"X": ["reshape2_output1"], },
+            outputs={
+                "Out": ["transpose2_output"],
+                "XShape": ["transpose2_xshape"]
+            },
+            axis=axis_v)
+        reshape2_op2 = OpConfig(
+            "reshape2",
+            inputs={"X": ["transpose2_output"], },
+            outputs={
+                "Out": ["reshape2_output2"],
+                "XShape": ["reshape2_xshape2"]
+            },
+            shape=x_shape)
+        ops = [reshape2_op1, transpose2_op, reshape2_op2]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            inputs={
+                "reshape2_input1":
+                TensorConfig(data_gen=partial(generate_reshape2_input)),
+            },
+            weights={},
+            outputs=["reshape2_output2"])
+        return program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            workspace_size=1 << 20,
+            max_batch_size=4,
+            min_subgraph_size=1,
+            precision_mode=paddle_infer.PrecisionType.Float32,
+            use_static=False,
+            use_calib_mode=False)
+        yield config, ['shuffle_channel'], (1e-5, 1e-5)
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            passes=["shuffle_channel_detect_pass"], )
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab