diff --git a/paddle/fluid/inference/tensorrt/convert/concat_op.cc b/paddle/fluid/inference/tensorrt/convert/concat_op.cc
index e9562235fda412b0d10f4197ee730f03365aac98..7f5630d9c8cb20cf91b27d4f38990b81c5238710 100644
--- a/paddle/fluid/inference/tensorrt/convert/concat_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/concat_op.cc
@@ -44,12 +44,17 @@ class ConcatOpConverter : public OpConverter {
       itensors.push_back(engine_->GetITensor(input_name));
     }
     int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
-
+    if (axis == -1) {
+      axis = (engine_->GetITensor(op_desc.Input("X").front())->getDimensions())
+                 .nbDims -
+             1;
+    } else {
+      if (!engine_->with_dynamic_shape()) {
+        axis = axis - 1;  // Remove batch dim
+      }
+    }
     auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(),
                                        itensors.size());
-    if (!engine_->with_dynamic_shape()) {
-      axis = axis - 1;  // Remove batch dim
-    }
     layer->setAxis(axis);
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "concat", {output_name}, test_mode);
diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc
index 7f270b1f390b7428aa40425ebfb2adb4d02620a8..5eb2de34298772d74edf62f06e1a641948e4bf51 100644
--- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc
@@ -1,11 +1,8 @@
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License"); you may not
 use this file except in compliance with the License. You may obtain a copy of
 the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -72,7 +69,8 @@ class SliceOpConverter : public OpConverter {
     nvinfer1::ILayer* layer = nullptr;
     if (engine_->with_dynamic_shape()) {
 #if IS_TRT_VERSION_GE(6000)
-      if (engine_->use_oss() && engine_->with_ernie()) {
+      if (engine_->use_oss() && engine_->with_ernie() &&
+          input_dims.nbDims == 4) {
         std::vector<nvinfer1::ITensor*> plugin_inputs;
         // plugin_inputs.emplace_back(trans_layer->getOutput(0));
         plugin_inputs.emplace_back(input);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 2899e1da962c2db644aa7e37a745ada34116538f..1a142b4d92b96b971c65450b0ffb36809e358a83 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -421,10 +421,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         return false;
       }
       int axis = BOOST_GET_CONST(int, desc.GetAttr("axis"));
-      if (with_dynamic_shape) {
-        if (axis < 0) return false;
-      } else {
-        if (axis <= 0) return false;
+      if (!with_dynamic_shape) {
+        if (axis == 0) return false;
       }
       auto concat_inputs = desc.Inputs();
       if (concat_inputs.find("AxisTensor") != concat_inputs.end()) {
diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
index 49c03b761ceb3e9c48dbee31ccd4a74f09b57a1c..ecf06e9bf15139990d5746a11592816ecde9f9f9 100644
--- a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
@@ -113,32 +113,38 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
 template <typename T>
 __global__ void SpecialSliceKernel(const T* slice_input,
                                    const int32_t* cu_seqlens, T* output) {
-  const int hidden = blockDim.x;
+  const int hidden = blockDim.x * gridDim.y;
   const int batch = blockIdx.x;
+  const int local_idx = blockIdx.y * blockDim.y + threadIdx.x;
 
-  output[batch * hidden + threadIdx.x] =
-      slice_input[cu_seqlens[batch] * hidden + threadIdx.x];
+  output[batch * hidden + local_idx] =
+      slice_input[cu_seqlens[batch] * hidden + local_idx];
 }
 
 int SpecialSlicePluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc* input_desc,
     const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
     void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
-  auto input_dims = input_desc[0].dims;  // (sum(S), 768, 1, 1)
-  auto out_dims = output_desc[0].dims;   // (batch, 768, 1, 1)
+  auto input_dims = input_desc[0].dims;  // (sum(S), hidden, 1, 1)
+  auto out_dims = output_desc[0].dims;   // (batch, hidden, 1, 1)
 
-  assert(input_desc[0].type == nvinfer1::DataType::kHALF);
+  PADDLE_ENFORCE_EQ(
+      input_desc[0].type, nvinfer1::DataType::kHALF,
+      platform::errors::InvalidArgument("Type of input should be half."));
 
   const int32_t hidden = input_dims.d[1];
-  const int num_blocks = out_dims.d[0];  // batch size
-  const int num_threads = hidden;
+  PADDLE_ENFORCE_EQ(hidden % 128, 0, platform::errors::InvalidArgument(
+                                         "hidden should be multiple of 128."));
+
+  constexpr int num_threads = 128;
+  const dim3 blocks(out_dims.d[0], hidden / num_threads);
 
   const half* slice_input = static_cast<const half*>(inputs[0]);
   const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
   half* output = static_cast<half*>(outputs[0]);
 
-  SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
-      slice_input, cu_seqlens, output);
+  SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input,
+                                                         cu_seqlens, output);
 
   return cudaGetLastError() != cudaSuccess;
 }
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_concat.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_concat.py
index e8a7649fd95fb104aa0bae4648f6a50dc0d7d611..ebd2f7724da22cdcdaf3a30ad6d51afee1fdaf67 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_concat.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_concat.py
@@ -71,7 +71,7 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest):
         def generate_weight1(attrs: List[Dict[str, Any]]):
             return np.zeros([1]).astype(np.int32)
 
-        for dims in [1, 2, 3, 4]:
+        for dims in [2, 3, 4]:
             for num_input in [0, 1]:
                 for batch in [1, 2, 4]:
                     for axis in [-1, 0, 1, 2, 3]:
@@ -277,12 +277,9 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest):
 
         def generate_trt_nodes_num(attrs, dynamic_shape):
             if dynamic_shape == True:
-                if attrs[0]['axis'] >= 0:
-                    return 1, 4
-                else:
-                    return 0, 5
+                return 1, 4
             else:
-                if attrs[0]['axis'] > 0:
+                if attrs[0]['axis'] != 0:
                     return 1, 4
                 else:
                     return 0, 5