diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc index 63cd4f1f8ef4860663e3490b522201972a9519e5..bcd7bedcc43a66564f5777cd139860bd546229e2 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -192,8 +192,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return; // trans on channel dim if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return; - - if (group != 1) { + if (group != 1 && i_c != 1) { if (trans_axis[1] != 2 && trans_axis[2] != 1) { return; } diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu index 3c7b59cec72b17594ef629935ad65a8aa08f72aa..ecf06e9bf15139990d5746a11592816ecde9f9f9 100644 --- a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu @@ -128,10 +128,14 @@ int SpecialSlicePluginDynamic::enqueue( auto input_dims = input_desc[0].dims; // (sum(S), hidden, 1, 1) auto out_dims = output_desc[0].dims; // (batch, hidden, 1, 1) - assert(input_desc[0].type == nvinfer1::DataType::kHALF); - assert(hidden % 128 == 0); + PADDLE_ENFORCE_EQ( + input_desc[0].type, nvinfer1::DataType::kHALF, + platform::errors::InvalidArgument("Type of input should be half.")); const int32_t hidden = input_dims.d[1]; + PADDLE_ENFORCE_EQ(hidden % 128, 0, platform::errors::InvalidArgument( + "hidden should be multiple of 128.")); + constexpr int num_threads = 128; const dim3 blocks(out_dims.d[0], hidden / num_threads);