From 91dd0f0dd29305cf71d9fe26fe60bb124d03a60d Mon Sep 17 00:00:00 2001 From: wenbin Date: Fri, 28 Jan 2022 11:55:39 +0800 Subject: [PATCH] compile fix (#39272) * slice * shuffle pass enhancement --- paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc | 3 +-- .../inference/tensorrt/plugin/special_slice_plugin.cu | 8 ++++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc index 63cd4f1f8ef..bcd7bedcc43 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -192,8 +192,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return; // trans on channel dim if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return; - - if (group != 1) { + if (group != 1 && i_c != 1) { if (trans_axis[1] != 2 && trans_axis[2] != 1) { return; } diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu index 3c7b59cec72..ecf06e9bf15 100644 --- a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu @@ -128,10 +128,14 @@ int SpecialSlicePluginDynamic::enqueue( auto input_dims = input_desc[0].dims; // (sum(S), hidden, 1, 1) auto out_dims = output_desc[0].dims; // (batch, hidden, 1, 1) - assert(input_desc[0].type == nvinfer1::DataType::kHALF); - assert(hidden % 128 == 0); + PADDLE_ENFORCE_EQ( + input_desc[0].type, nvinfer1::DataType::kHALF, + platform::errors::InvalidArgument("Type of input should be half.")); const int32_t hidden = input_dims.d[1]; + PADDLE_ENFORCE_EQ(hidden % 128, 0, platform::errors::InvalidArgument( + "hidden should be multiple of 128.")); + constexpr int num_threads = 128; const dim3 blocks(out_dims.d[0], hidden / num_threads); -- GitLab