diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
index ad426204d5aa18921d037a912fdf756edc361949..031202fb772279dd71c6b07e2da4e9fb5abadc64 100644
--- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
@@ -28,8 +28,8 @@ namespace tensorrt {
 namespace plugin {
 
 template <typename T>
-__global__ void SliceKernel(int num, int dims, const T *input,
-                            const int *offsets_info, T *output) {
+__global__ void SliceKernel(
+    int num, int dims, const T *input, const int *offsets_info, T *output) {
   const int idx = blockIdx.x * blockDim.x + threadIdx.x;
   extern __shared__ int shared_data[];
 
@@ -54,8 +54,10 @@ __global__ void SliceKernel(int num, int dims, const T *input,
   }
 }
 
-SlicePlugin::SlicePlugin(std::vector<int> starts, std::vector<int> ends,
-                         std::vector<int> axes, bool with_fp16)
+SlicePlugin::SlicePlugin(std::vector<int> starts,
+                         std::vector<int> ends,
+                         std::vector<int> axes,
+                         bool with_fp16)
     : starts_(starts), ends_(ends), axes_(axes) {
   with_fp16_ = with_fp16;
 }
@@ -79,10 +81,12 @@ bool SlicePlugin::supportsFormat(
     nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
   if (with_fp16_) {
     return ((type == nvinfer1::DataType::kFLOAT ||
-             type == nvinfer1::DataType::kHALF) &&
+             type == nvinfer1::DataType::kHALF ||
+             type == nvinfer1::DataType::kINT32) &&
             (format == nvinfer1::PluginFormat::kLINEAR));
   } else {
-    return ((type == nvinfer1::DataType::kFLOAT) &&
+    return ((type == nvinfer1::DataType::kFLOAT ||
+             type == nvinfer1::DataType::kINT32) &&
             (format == nvinfer1::PluginFormat::kLINEAR));
   }
 }
@@ -99,11 +103,15 @@ nvinfer1::Dims SlicePlugin::getOutputDimensions(
   return out_dims;
 }
 
-int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
+int SlicePlugin::enqueue(int batch_size,
+                         const void *const *inputs,
 #if IS_TRT_VERSION_LT(8000)
-                         void **outputs, void *workspace, cudaStream_t stream) {
+                         void **outputs,
+                         void *workspace,
+                         cudaStream_t stream) {
 #else
-                         void *const *outputs, void *workspace,
+                         void *const *outputs,
+                         void *workspace,
                          cudaStream_t stream) TRT_NOEXCEPT {
 #endif
   auto input_dims = getInputDims(0);
@@ -153,8 +161,11 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
     cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
   }
 
-  cudaMemcpyAsync(offset_temp_data_, offset_info.data(),
-                  sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream);
+  cudaMemcpyAsync(offset_temp_data_,
+                  offset_info.data(),
+                  sizeof(int) * 3 * num_dims,
+                  cudaMemcpyHostToDevice,
+                  stream);
 
   int threads = 256;
   int blocks = (out_num + threads - 1) / threads;
@@ -171,9 +182,15 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
     half *output = static_cast<half *>(outputs[0]);
     SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
         out_num, num_dims, input1, offset_temp_data_, output);
+  } else if (input_type == nvinfer1::DataType::kINT32) {
+    VLOG(1) << "TRT Plugin DataType selected. Slice-->int32";
+    const int *input1 = static_cast<const int *>(inputs[0]);
+    int *output = static_cast<int *>(outputs[0]);
+    SliceKernel<int><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
+        out_num, num_dims, input1, offset_temp_data_, output);
   } else {
     PADDLE_THROW(platform::errors::Fatal(
-        "The Slice TRT Plugin's input type should be float or half."));
+        "The Slice TRT Plugin's input type should be float, half or int."));
   }
   return cudaGetLastError() != cudaSuccess;
 }
@@ -197,7 +214,8 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
 #if IS_TRT_VERSION_GE(6000)
 SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
                                        std::vector<int> ends,
-                                       std::vector<int> axes, int decrease_axis,
+                                       std::vector<int> axes,
+                                       int decrease_axis,
                                        bool with_fp16)
     : starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) {
   with_fp16_ = with_fp16;
 }
@@ -238,7 +256,9 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
 }
 
 nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
-    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
+    int output_index,
+    const nvinfer1::DimsExprs *inputs,
+    int nb_inputs,
     nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
   auto in_dims = inputs[0];
   nvinfer1::DimsExprs ret = in_dims;
@@ -264,7 +284,8 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
     for (size_t i = 0; i < in_dims.nbDims; i++) {
       if (decrease_axis_ == i) continue;
       res.d[j++] = expr_builder.operation(nvinfer1::DimensionOperation::kMAX,
-                                          *expr_builder.constant(0), *ret.d[i]);
+                                          *expr_builder.constant(0),
+                                          *ret.d[i]);
     }
     return res;
   }
@@ -272,26 +293,33 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
 }
 
 bool SlicePluginDynamic::supportsFormatCombination(
-    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
+    int pos,
+    const nvinfer1::PluginTensorDesc *in_out,
+    int nb_inputs,
     int nb_outputs) TRT_NOEXCEPT {
   PADDLE_ENFORCE_NOT_NULL(
-      in_out, platform::errors::InvalidArgument(
-                  "The input of swish plugin shoule not be nullptr."));
+      in_out,
+      platform::errors::InvalidArgument(
+          "The input of swish plugin shoule not be nullptr."));
   PADDLE_ENFORCE_LT(
-      pos, nb_inputs + nb_outputs,
+      pos,
+      nb_inputs + nb_outputs,
       platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                         "num(%d) of the input and the output.",
-                                        pos, nb_inputs + nb_outputs));
+                                        pos,
+                                        nb_inputs + nb_outputs));
 
   const nvinfer1::PluginTensorDesc &in = in_out[pos];
   if (pos == 0) {
     if (with_fp16_) {
       return (in.type == nvinfer1::DataType::kFLOAT ||
-              in.type == nvinfer1::DataType::kHALF) &&
+              in.type == nvinfer1::DataType::kHALF ||
+              in.type == nvinfer1::DataType::kINT32) &&
              (in.format == nvinfer1::TensorFormat::kLINEAR);
     } else {
-      return (in.type == nvinfer1::DataType::kFLOAT) &&
+      return (in.type == nvinfer1::DataType::kFLOAT ||
+              in.type == nvinfer1::DataType::kINT32) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
     }
   }
@@ -301,24 +329,28 @@ bool SlicePluginDynamic::supportsFormatCombination(
 }
 
 nvinfer1::DataType SlicePluginDynamic::getOutputDataType(
-    int index, const nvinfer1::DataType *input_types,
+    int index,
+    const nvinfer1::DataType *input_types,
     int nb_inputs) const TRT_NOEXCEPT {
-  PADDLE_ENFORCE_EQ(index, 0,
+  PADDLE_ENFORCE_EQ(index,
+                    0,
                     platform::errors::InvalidArgument(
                         "The Slice Plugin only has one input, so the "
                         "index value should be 0, but get %d.",
                         index));
   PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
-                     input_types[0] == nvinfer1::DataType::kHALF),
+                     input_types[0] == nvinfer1::DataType::kHALF ||
+                     input_types[0] == nvinfer1::DataType::kINT32),
                     true,
                     platform::errors::InvalidArgument(
-                        "The input type should be half or float"));
+                        "The input type should be half, float or int"));
   return input_types[0];
 }
 
 int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                                 const nvinfer1::PluginTensorDesc *output_desc,
-                                const void *const *inputs, void *const *outputs,
+                                const void *const *inputs,
+                                void *const *outputs,
                                 void *workspace,
                                 cudaStream_t stream) TRT_NOEXCEPT {
   auto input_dims = input_desc[0].dims;
@@ -362,8 +394,11 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
     cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
   }
 
-  cudaMemcpyAsync(offset_temp_data_, offset_info_.data(),
-                  sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream);
+  cudaMemcpyAsync(offset_temp_data_,
+                  offset_info_.data(),
+                  sizeof(int) * 3 * num_dims,
+                  cudaMemcpyHostToDevice,
+                  stream);
 
   int threads = 256;
   int blocks = (out_num + threads - 1) / threads;
@@ -380,9 +415,15 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
     half *output = static_cast<half *>(outputs[0]);
     SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
+  } else if (input_type == nvinfer1::DataType::kINT32) {
+    VLOG(1) << "TRT Plugin DataType selected. Slice-->int32";
+    const int *input1 = static_cast<const int *>(inputs[0]);
+    int *output = static_cast<int *>(outputs[0]);
+    SliceKernel<int><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
+        out_num, num_dims, input1, offset_temp_data_, output);
   } else {
     PADDLE_THROW(platform::errors::Fatal(
-        "The Slice TRT Plugin's input type should be float or half."));
+        "The Slice TRT Plugin's input type should be float, half or int."));
   }
   return cudaGetLastError() != cudaSuccess;
 }
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py
index a1249c04c27367a3547309b8d4fd3f5fcc731f34..41ad0851784f7f90777f6274cf1129c288eaf352 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py
@@ -108,5 +108,56 @@ class StaticSlicePluginTRTTestFp32(SlicePluginTRTTest):
         self.enable_trt = True
 
 
+class SlicePluginTRTTestInt32(SlicePluginTRTTest):
+
+    def setUp(self):
+        self.setUpSliceParams()
+        self.setUpTensorRTParams()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="int32")
+            axes = self.params_axes
+            starts = self.params_starts
+            ends = self.params_ends
+            slice_out = fluid.layers.slice(data,
+                                           axes=axes,
+                                           starts=starts,
+                                           ends=ends)
+            cast_out = fluid.layers.cast(slice_out, 'float32')
+            out = fluid.layers.batch_norm(cast_out, is_test=True)
+
+        self.feeds = {
+            "data": np.random.random((3, 3, 3, 3)).astype("int32"),
+        }
+        self.fetch_list = [out]
+
+
+class StaticSlicePluginTRTTestInt32(SlicePluginTRTTest):
+
+    def setUpTensorRTParams(self):
+        self.trt_parameters = SlicePluginTRTTest.TensorRTParam(
+            1 << 30, 32, 1, AnalysisConfig.Precision.Float32, True, False)
+        self.enable_trt = True
+
+    def setUp(self):
+        self.setUpSliceParams()
+        self.setUpTensorRTParams()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="int32")
+            axes = self.params_axes
+            starts = self.params_starts
+            ends = self.params_ends
+            slice_out = fluid.layers.slice(data,
+                                           axes=axes,
+                                           starts=starts,
+                                           ends=ends)
+            cast_out = fluid.layers.cast(slice_out, 'float32')
+            out = fluid.layers.batch_norm(cast_out, is_test=True)
+
+        self.feeds = {
+            "data": np.random.random((3, 3, 3, 3)).astype("int32"),
+        }
+        self.fetch_list = [out]
+
+
 if __name__ == "__main__":
     unittest.main()
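
Note on the new kINT32 path: the diff reuses the existing templated SliceKernel and only adds an int instantiation plus the matching pointer casts, so no new device code is introduced. The kernel body and the exact packing of offset_info are not shown in this diff; assuming the staging buffer packs three ints per axis (slice start, output extent, input stride), which is consistent with the 3 * num_dims * sizeof(int) allocation above, a minimal standalone CUDA sketch of an int32 slice could look like the following. SliceKernelSketch, the shapes, and the offsets below are illustrative only, not the plugin's actual code.

// Standalone sketch (not the plugin): slice an int32 tensor of shape
// [2, 4, 4] down to [2, 2, 2] starting at (0, 1, 1). offsets_info is
// ASSUMED to pack {start, out_extent, in_stride} per dimension.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

template <typename T>
__global__ void SliceKernelSketch(int num, int dims, const T *input,
                                  const int *offsets_info, T *output) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num) return;
  // Decompose the linear output index into per-axis coordinates, add the
  // slice start on each axis, and re-linearize into the input.
  int t = idx, in_idx = 0;
  for (int i = dims - 1; i >= 0; --i) {
    int start = offsets_info[i * 3];
    int out_extent = offsets_info[i * 3 + 1];
    int in_stride = offsets_info[i * 3 + 2];
    in_idx += (t % out_extent + start) * in_stride;
    t /= out_extent;
  }
  output[idx] = input[in_idx];
}

int main() {
  const int dims = 3;
  const int in_shape[dims] = {2, 4, 4};
  const int starts[dims] = {0, 1, 1};
  const int out_shape[dims] = {2, 2, 2};

  // Build the {start, out_extent, in_stride} triples, innermost stride == 1.
  std::vector<int> info(3 * dims);
  int stride = 1;
  for (int i = dims - 1; i >= 0; --i) {
    info[i * 3] = starts[i];
    info[i * 3 + 1] = out_shape[i];
    info[i * 3 + 2] = stride;
    stride *= in_shape[i];
  }

  int in_num = 2 * 4 * 4, out_num = 2 * 2 * 2;
  std::vector<int> h_in(in_num);
  for (int i = 0; i < in_num; ++i) h_in[i] = i;  // iota input: 0..31

  int *d_in, *d_out, *d_info;
  cudaMalloc(&d_in, in_num * sizeof(int));
  cudaMalloc(&d_out, out_num * sizeof(int));
  cudaMalloc(&d_info, info.size() * sizeof(int));
  cudaMemcpy(d_in, h_in.data(), in_num * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_info, info.data(), info.size() * sizeof(int),
             cudaMemcpyHostToDevice);

  int threads = 256, blocks = (out_num + threads - 1) / threads;
  SliceKernelSketch<int><<<blocks, threads>>>(out_num, dims, d_in, d_info, d_out);

  std::vector<int> h_out(out_num);
  cudaMemcpy(h_out.data(), d_out, out_num * sizeof(int), cudaMemcpyDeviceToHost);
  for (int v : h_out) printf("%d ", v);  // expected: 5 6 9 10 21 22 25 26
  printf("\n");

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_info);
  return 0;
}

Built with nvcc and run, the sketch prints 5 6 9 10 21 22 25 26, i.e. the [0:2, 1:3, 1:3] window of a 2x4x4 tensor filled with 0..31, mirroring (under the stated assumption about offset_info) the kind of index arithmetic the plugin's enqueue path drives for float, half, and now int32.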