From a2b39320375bb90e4208670d1eae79deb7785e15 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Fri, 22 Jul 2022 10:13:26 +0800 Subject: [PATCH] shufflechannelfix (#44516) --- .../tensorrt/convert/shuffle_channel_op.cc | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc index 16264e82cf1..fa4c6f02bcd 100644 --- a/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc @@ -46,26 +46,23 @@ class ShuffleChannelOpConverter : public OpConverter { #if IS_TRT_VERSION_GE(8000) if (engine_->with_dynamic_shape()) { auto* input_shape_tensor = Shape(input); + auto* batch_shape_tensor = GetEleTensorOfShape(input_shape_tensor, 0); auto* channel_shape_tensor = GetEleTensorOfShape(input_shape_tensor, 1); auto* group_tensor = Add1DConstantLayer(group, output_name + "_group_tensor_"); auto* new_channel_shape_tensor = Div(channel_shape_tensor, group_tensor); - std::vector shape_dim3{0, 2, 3}; - auto* shape_dim3_tensor = Gather(input_shape_tensor, shape_dim3); + std::vector shape_dim2{2, 3}; + auto* shape_dim2_tensor = Gather(input_shape_tensor, shape_dim2); std::vector itensors; - itensors.push_back(shape_dim3_tensor); + itensors.push_back(batch_shape_tensor); itensors.push_back(group_tensor); itensors.push_back(new_channel_shape_tensor); + itensors.push_back(shape_dim2_tensor); auto* reshape_tensor = Concat(itensors); - auto* reshape_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *reshape_tensor); - nvinfer1::Permutation transpose_new_input{0, 3, 4, 1, 2}; - reshape_layer->setSecondTranspose(transpose_new_input); - auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); - layer->setInput(1, *(reshape_layer->getOutput(0))); + layer->setInput(1, *(reshape_tensor)); nvinfer1::Permutation transpose_embed{0, 2, 1, 3, 4}; layer->setSecondTranspose(transpose_embed); auto* output = layer->getOutput(0); -- GitLab